ter.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, List, Optional, Union
  16. import torch
  17. from torch import Tensor, tensor
  18. from torchmetrics.functional.text.ter import _ter_compute, _ter_update, _TercomTokenizer
  19. from torchmetrics.metric import Metric
  20. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  21. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  # ``TranslationEditRate.plot`` needs matplotlib (its docstring lists ModuleNotFoundError otherwise),
  # so its doctest examples are only skipped when matplotlib is unavailable.
  if not bool(_MATPLOTLIB_AVAILABLE):
      __doctest_skip__ = ["TranslationEditRate.plot"]
  class TranslationEditRate(Metric):
      """Calculate Translation edit rate (`TER`_) of machine translated text with one or more references.

      This implementation follows the one from `SacreBleu_ter`_, which is a near-exact reimplementation of the
      Tercom algorithm and produces identical results on all "sane" outputs.

      As input to ``forward`` and ``update`` the metric accepts the following input:

      - ``preds`` (:class:`~Sequence`): An iterable of hypothesis corpus
      - ``target`` (:class:`~Sequence`): An iterable of iterables of reference corpus

      As output of ``forward`` and ``compute`` the metric returns the following output:

      - ``ter`` (:class:`~torch.Tensor`): if ``return_sentence_level_score=True``, a tuple of the corpus-level
        translation edit rate and a tensor of sentence-level translation edit rates; otherwise only the
        corpus-level translation edit rate

      Args:
          normalize: Whether to apply general tokenization to the sentences.
          no_punctuation: Whether to remove punctuation from the sentences.
          lowercase: Whether to enable case-insensitivity.
          asian_support: Whether to process Asian characters.
          return_sentence_level_score: Whether to also return sentence-level TER scores.
          kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

      Raises:
          ValueError:
              If any of ``normalize``, ``no_punctuation``, ``lowercase`` or ``asian_support`` is not a ``bool``.

      Example:
          >>> from torchmetrics.text import TranslationEditRate
          >>> preds = ['the cat is on the mat']
          >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
          >>> ter = TranslationEditRate()
          >>> ter(preds, target)
          tensor(0.1538)

      """

      is_differentiable: bool = False
      higher_is_better: bool = False
      full_state_update: bool = False
      plot_lower_bound: float = 0.0
      plot_upper_bound: float = 1.0

      total_num_edits: Tensor
      total_tgt_len: Tensor
      # Per-sentence scores. Registered as a state only when ``return_sentence_level_score=True``;
      # otherwise this class-level ``None`` is what ``update``/``compute`` see.
      sentence_ter: Optional[List[Tensor]] = None

      def __init__(
          self,
          normalize: bool = False,
          no_punctuation: bool = False,
          lowercase: bool = True,
          asian_support: bool = False,
          return_sentence_level_score: bool = False,
          **kwargs: Any,
      ) -> None:
          super().__init__(**kwargs)
          # Validate the tokenizer flags in a fixed order so the first offending argument is reported.
          for arg_name, arg_value in (
              ("normalize", normalize),
              ("no_punctuation", no_punctuation),
              ("lowercase", lowercase),
              ("asian_support", asian_support),
          ):
              if not isinstance(arg_value, bool):
                  raise ValueError(f"Expected argument `{arg_name}` to be of type boolean but got {arg_value}.")

          self.tokenizer = _TercomTokenizer(normalize, no_punctuation, lowercase, asian_support)
          self.return_sentence_level_score = return_sentence_level_score

          # Corpus-level statistics are summed across processes on sync.
          self.add_state("total_num_edits", tensor(0.0), dist_reduce_fx="sum")
          self.add_state("total_tgt_len", tensor(0.0), dist_reduce_fx="sum")
          if self.return_sentence_level_score:
              self.add_state("sentence_ter", [], dist_reduce_fx="cat")

      def update(self, preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]]) -> None:
          """Update state with predictions and targets.

          Args:
              preds: A hypothesis sentence or a sequence of hypothesis sentences.
              target: The reference(s) for each hypothesis.

          """
          self.total_num_edits, self.total_tgt_len, self.sentence_ter = _ter_update(
              preds,
              target,
              self.tokenizer,
              self.total_num_edits,
              self.total_tgt_len,
              self.sentence_ter,
          )

      def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]:
          """Calculate the translation edit rate (TER).

          Returns:
              The corpus-level TER. If sentence-level scores are tracked, a tuple of the corpus-level TER and a
              tensor of the sentence-level scores.

          """
          ter = _ter_compute(self.total_num_edits, self.total_tgt_len)
          if self.sentence_ter is None:
              return ter
          return ter, self._concat_sentence_ter(self.sentence_ter)

      def _concat_sentence_ter(self, sentence_ter: Union[List[Tensor], Tensor]) -> Tensor:
          """Merge the ``sentence_ter`` state into a single tensor, whatever form it is currently in."""
          # With ``dist_reduce_fx="cat"`` the state may already be one tensor after a distributed sync,
          # and ``torch.cat`` only accepts a list/tuple of tensors, not a bare tensor.
          if isinstance(sentence_ter, Tensor):
              return sentence_ter
          if not sentence_ter:
              # ``torch.cat`` raises on an empty list (e.g. ``compute`` before any ``update``); return an empty
              # tensor on the metric's device/dtype instead. Assumes each stored score is 1-D — TODO confirm.
              return self.total_num_edits.new_zeros(0)
          return torch.cat(sentence_ter)

      def plot(
          self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
      ) -> _PLOT_OUT_TYPE:
          """Plot a single or multiple values from the metric.

          Args:
              val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                  If no value is provided, will automatically call `metric.compute` and plot that result.
              ax: A matplotlib axis object. If provided will add plot to that axis

          Returns:
              Figure and Axes object

          Raises:
              ModuleNotFoundError:
                  If `matplotlib` is not installed

          .. plot::
              :scale: 75

              >>> # Example plotting a single value
              >>> from torchmetrics.text import TranslationEditRate
              >>> metric = TranslationEditRate()
              >>> preds = ['the cat is on the mat']
              >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
              >>> metric.update(preds, target)
              >>> fig_, ax_ = metric.plot()

          .. plot::
              :scale: 75

              >>> # Example plotting multiple values
              >>> from torchmetrics.text import TranslationEditRate
              >>> metric = TranslationEditRate()
              >>> preds = ['the cat is on the mat']
              >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
              >>> values = [ ]
              >>> for _ in range(10):
              ...     values.append(metric(preds, target))
              >>> fig_, ax_ = metric.plot(values)

          """
          return self._plot(val, ax)