eed.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # referenced from
  15. # Library Name: torchtext
  16. # Authors: torchtext authors
  17. # Date: 2021-12-07
  18. # Link:
  19. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  20. # The RWTH Extended Edit Distance (EED) License
  21. # Copyright (c) 2019, RWTH.
  22. # All rights reserved.
  23. # This license is derived from the Q Public License v1.0 and the Qt Non-Commercial License v1.0 which are both Copyright
  24. # by Trolltech AS, Norway. The aim of this license is to lay down the conditions enabling you to use, modify and
  25. # circulate the SOFTWARE, use of third-party application programs based on the Software and publication of results
  26. # obtained through the use of modified and unmodified versions of the SOFTWARE. However, RWTH remain the authors of the
  27. # SOFTWARE and so retain property rights and the use of all ancillary rights. The SOFTWARE is defined as all successive
  28. # versions of EED software and their documentation that have been developed by RWTH.
  29. #
  30. # When you access and use the SOFTWARE, you are presumed to be aware of and to have accepted all the rights and
  31. # obligations of the present license:
  32. #
  33. # 1. You are granted the non-exclusive rights set forth in this license provided you agree to and comply with any all
  34. # conditions in this license. Whole or partial distribution of the Software, or software items that link with the
  35. # Software, in any form signifies acceptance of this license for non-commercial use only.
  36. # 2. You may copy and distribute the Software in unmodified form provided that the entire package, including - but not
  37. # restricted to - copyright, trademark notices and disclaimers, as released by the initial developer of the
  38. # Software, is distributed.
  39. # 3. You may make modifications to the Software and distribute your modifications, in a form that is separate from the
  40. # Software, such as patches. The following restrictions apply to modifications:
  41. # a. Modifications must not alter or remove any copyright notices in the Software.
  42. # b When modifications to the Software are released under this license, a non-exclusive royalty-free right is
  43. # granted to the initial developer of the Software to distribute your modification in future versions of the
  44. # Software provided such versions remain available under these terms in addition to any other license(s) of the
  45. # initial developer.
  46. # 4. You may distribute machine-executable forms of the Software or machine-executable forms of modified versions of
  47. # the Software, provided that you meet these restrictions:
  48. # a. You must include this license document in the distribution.
  49. # b. You must ensure that all recipients of the machine-executable forms are also able to receive the complete
  50. # machine-readable source code to the distributed Software, including all modifications, without any charge
  51. # beyond the costs of data transfer, and place prominent notices in the distribution explaining this.
  52. # c. You must ensure that all modifications included in the machine-executable forms are available under the terms
  53. # of this license.
  54. # 5. You may use the original or modified versions of the Software to compile, link and run application programs
  55. # legally developed by you or by others.
  56. # 6. You may develop application programs, reusable components and other software items, in a non-commercial setting,
  57. # that link with the original or modified versions of the Software. These items, when distributed, are subject to
  58. # the following requirements:
  59. # a. You must ensure that all recipients of machine-executable forms of these items are also able to receive and use
  60. # the complete machine-readable source code to the items without any charge beyond the costs of data transfer.
  61. # b. You must explicitly license all recipients of your items to use and re-distribute original and modified
  62. # versions of the items in both machine-executable and source code forms. The recipients must be able to do so
  63. # without any charges whatsoever, and they must be able to re-distribute to anyone they choose.
  64. # c. If an application program gives you access to functionality of the Software for development of application
  65. # programs, reusable components or other software components (e.g. an application that is a scripting wrapper),
  66. # usage of the application program is considered to be usage of the Software and is thus bound by this license.
  67. # d. If the items are not available to the general public, and the initial developer of the Software requests a copy
  68. # of the items, then you must supply one.
  69. # 7. Users must cite the authors of the Software upon publication of results obtained through the use of original or
  70. # modified versions of the Software by referring to the following publication:
  71. # P. Stanchev, W. Wang, and H. Ney, “EED: Extended Edit Distance Measure for Machine Translation”, submitted to WMT
  72. # 2019.
  73. # 8. In no event shall the initial developers or copyright holders be liable for any damages whatsoever, including -
  74. # but not restricted to - lost revenue or profits or other direct, indirect, special, incidental or consequential
  75. # damages, even if they have been advised of the possibility of such damages, except to the extent invariable law,
  76. # if any, provides otherwise.
  77. # 9. You assume all risks concerning the quality or the effects of the SOFTWARE and its use. If the SOFTWARE is
  78. # defective, you will bear the costs of all required services, corrections or repairs.
  79. # 10. This license has the binding value of a contract.
  80. # 11. The present license and its effects are subject to German law and the competent German Courts.
  81. #
  82. # The Software and this license document are provided "AS IS" with NO EXPLICIT OR IMPLICIT WARRANTY OF ANY KIND,
  83. # INCLUDING WARRANTY OF DESIGN, ADAPTION, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
  84. import re
  85. import unicodedata
  86. from collections.abc import Sequence
  87. from math import inf
  88. from typing import List, Optional, Union
  89. from torch import Tensor, stack, tensor
  90. from typing_extensions import Literal
  91. from torchmetrics.functional.text.helper import _validate_inputs
  92. def _distance_between_words(preds_word: str, target_word: str) -> int:
  93. """Distance measure used for substitutions/identity operation.
  94. Code adapted from https://github.com/rwth-i6/ExtendedEditDistance/blob/master/EED.py.
  95. Args:
  96. preds_word: hypothesis word string
  97. target_word: reference word string
  98. Return:
  99. 0 for match, 1 for no match
  100. """
  101. return int(preds_word != target_word)
  102. def _eed_function(
  103. hyp: str,
  104. ref: str,
  105. alpha: float = 2.0,
  106. rho: float = 0.3,
  107. deletion: float = 0.2,
  108. insertion: float = 1.0,
  109. ) -> float:
  110. """Compute extended edit distance score for two lists of strings: hyp and ref.
  111. Code adapted from: https://github.com/rwth-i6/ExtendedEditDistance/blob/master/EED.py.
  112. Args:
  113. hyp: A hypothesis string
  114. ref: A reference string
  115. alpha: optimal jump penalty, penalty for jumps between characters
  116. rho: coverage cost, penalty for repetition of characters
  117. deletion: penalty for deletion of character
  118. insertion: penalty for insertion or substitution of character
  119. Return:
  120. Extended edit distance score as float
  121. """
  122. number_of_visits = [-1] * (len(hyp) + 1)
  123. # row[i] stores cost of cheapest path from (0,0) to (i,l) in CDER alignment grid.
  124. row = [1.0] * (len(hyp) + 1)
  125. row[0] = 0.0 # CDER initialisation 0,0 = 0.0, rest 1.0
  126. next_row = [inf] * (len(hyp) + 1)
  127. for w in range(1, len(ref) + 1):
  128. for i in range(len(hyp) + 1):
  129. if i > 0:
  130. next_row[i] = min(
  131. next_row[i - 1] + deletion,
  132. row[i - 1] + _distance_between_words(hyp[i - 1], ref[w - 1]),
  133. row[i] + insertion,
  134. )
  135. else:
  136. next_row[i] = row[i] + 1.0
  137. min_index = next_row.index(min(next_row))
  138. number_of_visits[min_index] += 1
  139. # Long Jumps
  140. if ref[w - 1] == " ":
  141. jump = alpha + next_row[min_index]
  142. next_row = [min(x, jump) for x in next_row]
  143. row = next_row
  144. next_row = [inf] * (len(hyp) + 1)
  145. coverage = rho * sum(x if x >= 0 else 1 for x in number_of_visits)
  146. return min(1, (row[-1] + coverage) / (float(len(ref)) + coverage))
  147. def _preprocess_en(sentence: str) -> str:
  148. """Preprocess english sentences.
  149. Copied from https://github.com/rwth-i6/ExtendedEditDistance/blob/master/util.py.
  150. Raises:
  151. ValueError: If input sentence is not of a type `str`.
  152. """
  153. if not isinstance(sentence, str):
  154. raise ValueError(f"Only strings allowed during preprocessing step, found {type(sentence)} instead")
  155. sentence = sentence.rstrip() # trailing space, tab, or newline
  156. # Add space before interpunctions
  157. rules_interpunction = [
  158. (".", " ."),
  159. ("!", " !"),
  160. ("?", " ?"),
  161. (",", " ,"),
  162. ]
  163. for pattern, replacement in rules_interpunction:
  164. sentence = sentence.replace(pattern, replacement)
  165. rules_re = [
  166. (r"\s+", r" "), # get rid of extra spaces
  167. (r"(\d) ([.,]) (\d)", r"\1\2\3"), # 0 . 1 -> 0.1
  168. (r"(Dr|Jr|Prof|Rev|Gen|Mr|Mt|Mrs|Ms) .", r"\1."), # Mr . -> Mr.
  169. ]
  170. for pattern, replacement in rules_re:
  171. sentence = re.sub(pattern, replacement, sentence)
  172. # Add space between abbreviations
  173. rules_interpunction = [
  174. ("e . g .", "e.g."),
  175. ("i . e .", "i.e."),
  176. ("U . S .", "U.S."),
  177. ]
  178. for pattern, replacement in rules_interpunction:
  179. sentence = sentence.replace(pattern, replacement)
  180. # add space to beginning and end of string
  181. return " " + sentence + " "
  182. def _preprocess_ja(sentence: str) -> str:
  183. """Preprocess japanese sentences.
  184. Copy from https://github.com/rwth-i6/ExtendedEditDistance/blob/master/util.py.
  185. Raises:
  186. ValueError: If input sentence is not of a type `str`.
  187. """
  188. if not isinstance(sentence, str):
  189. raise ValueError(f"Only strings allowed during preprocessing step, found {type(sentence)} instead")
  190. sentence = sentence.rstrip() # trailing space, tab, newline
  191. # characters which look identical actually are identical
  192. return unicodedata.normalize("NFKC", sentence)
  193. def _eed_compute(sentence_level_scores: List[Tensor]) -> Tensor:
  194. """Reduction for extended edit distance.
  195. Args:
  196. sentence_level_scores: list of sentence-level scores as floats
  197. Return:
  198. average of scores as a tensor
  199. """
  200. if len(sentence_level_scores) == 0:
  201. return tensor(0.0)
  202. return sum(sentence_level_scores) / tensor(len(sentence_level_scores))
  203. def _preprocess_sentences(
  204. preds: Union[str, Sequence[str]],
  205. target: Sequence[Union[str, Sequence[str]]],
  206. language: Literal["en", "ja"],
  207. ) -> tuple[Union[str, Sequence[str]], Sequence[Union[str, Sequence[str]]]]:
  208. """Preprocess strings according to language requirements.
  209. Args:
  210. preds: An iterable of hypothesis corpus.
  211. target: An iterable of iterables of reference corpus.
  212. language: Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en
  213. Return:
  214. Tuple of lists that contain the cleaned strings for target and preds
  215. Raises:
  216. ValueError: If a different language than ``'en'`` or ``'ja'`` is used
  217. ValueError: If length of target not equal to length of preds
  218. ValueError: If objects in reference and hypothesis corpus are not strings
  219. """
  220. # sanity checks
  221. target, preds = _validate_inputs(hypothesis_corpus=preds, ref_corpus=target)
  222. # preprocess string
  223. if language == "en":
  224. preprocess_function = _preprocess_en
  225. elif language == "ja":
  226. preprocess_function = _preprocess_ja
  227. else:
  228. raise ValueError(f"Expected argument `language` to either be `en` or `ja` but got {language}")
  229. preds = [preprocess_function(pred) for pred in preds]
  230. target = [[preprocess_function(ref) for ref in reference] for reference in target]
  231. return preds, target
  232. def _compute_sentence_statistics(
  233. preds_word: str,
  234. target_words: Union[str, Sequence[str]],
  235. alpha: float = 2.0,
  236. rho: float = 0.3,
  237. deletion: float = 0.2,
  238. insertion: float = 1.0,
  239. ) -> Tensor:
  240. """Compute scores for ExtendedEditDistance.
  241. Args:
  242. target_words: An iterable of reference words
  243. preds_word: A hypothesis word
  244. alpha: An optimal jump penalty, penalty for jumps between characters
  245. rho: coverage cost, penalty for repetition of characters
  246. deletion: penalty for deletion of character
  247. insertion: penalty for insertion or substitution of character
  248. Return:
  249. best_score: best (lowest) sentence-level score as a Tensor
  250. """
  251. best_score = inf
  252. for reference in target_words:
  253. score = _eed_function(preds_word, reference, alpha, rho, deletion, insertion)
  254. if score < best_score:
  255. best_score = score
  256. return tensor(best_score)
  257. def _eed_update(
  258. preds: Union[str, Sequence[str]],
  259. target: Sequence[Union[str, Sequence[str]]],
  260. language: Literal["en", "ja"] = "en",
  261. alpha: float = 2.0,
  262. rho: float = 0.3,
  263. deletion: float = 0.2,
  264. insertion: float = 1.0,
  265. sentence_eed: Optional[List[Tensor]] = None,
  266. ) -> List[Tensor]:
  267. """Compute scores for ExtendedEditDistance.
  268. Args:
  269. preds: An iterable of hypothesis corpus
  270. target: An iterable of iterables of reference corpus
  271. language: Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en
  272. alpha: optimal jump penalty, penalty for jumps between characters
  273. rho: coverage cost, penalty for repetition of characters
  274. deletion: penalty for deletion of character
  275. insertion: penalty for insertion or substitution of character
  276. sentence_eed: list of sentence-level scores
  277. Return:
  278. individual sentence scores as a list of Tensors
  279. """
  280. preds, target = _preprocess_sentences(preds, target, language)
  281. if sentence_eed is None:
  282. sentence_eed = []
  283. # return tensor(0.0) if target or preds is empty
  284. if 0 in (len(preds), len(target[0])):
  285. return sentence_eed
  286. for hypothesis, target_words in zip(preds, target):
  287. score = _compute_sentence_statistics(hypothesis, target_words, alpha, rho, deletion, insertion)
  288. sentence_eed.append(score)
  289. return sentence_eed
  290. def extended_edit_distance(
  291. preds: Union[str, Sequence[str]],
  292. target: Sequence[Union[str, Sequence[str]]],
  293. language: Literal["en", "ja"] = "en",
  294. return_sentence_level_score: bool = False,
  295. alpha: float = 2.0,
  296. rho: float = 0.3,
  297. deletion: float = 0.2,
  298. insertion: float = 1.0,
  299. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  300. """Compute extended edit distance score (`ExtendedEditDistance`_) [1] for strings or list of strings.
  301. The metric utilises the Levenshtein distance and extends it by adding a jump operation.
  302. Args:
  303. preds: An iterable of hypothesis corpus.
  304. target: An iterable of iterables of reference corpus.
  305. language: Language used in sentences. Only supports English (en) and Japanese (ja) for now. Defaults to en
  306. return_sentence_level_score: An indication of whether sentence-level EED score is to be returned.
  307. alpha: optimal jump penalty, penalty for jumps between characters
  308. rho: coverage cost, penalty for repetition of characters
  309. deletion: penalty for deletion of character
  310. insertion: penalty for insertion or substitution of character
  311. Return:
  312. Extended edit distance score as a tensor
  313. Example:
  314. >>> from torchmetrics.functional.text import extended_edit_distance
  315. >>> preds = ["this is the prediction", "here is an other sample"]
  316. >>> target = ["this is the reference", "here is another one"]
  317. >>> extended_edit_distance(preds=preds, target=target)
  318. tensor(0.3078)
  319. References:
  320. [1] P. Stanchev, W. Wang, and H. Ney, “EED: Extended Edit Distance Measure for Machine Translation”,
  321. submitted to WMT 2019. `ExtendedEditDistance`_
  322. """
  323. # input validation for parameters
  324. for param_name, param in zip(["alpha", "rho", "deletion", "insertion"], [alpha, rho, deletion, insertion]):
  325. if not isinstance(param, float) or (isinstance(param, float) and param < 0):
  326. raise ValueError(f"Parameter `{param_name}` is expected to be a non-negative float.")
  327. sentence_level_scores = _eed_update(preds, target, language, alpha, rho, deletion, insertion)
  328. average = _eed_compute(sentence_level_scores)
  329. if return_sentence_level_score:
  330. return average, stack(sentence_level_scores)
  331. return average