helper.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  15. # Copyright 2020 Memsource
  16. #
  17. # Licensed under the Apache License, Version 2.0 (the "License");
  18. # you may not use this file except in compliance with the License.
  19. # You may obtain a copy of the License at
  20. #
  21. # http://www.apache.org/licenses/LICENSE-2.0
  22. #
  23. # Unless required by applicable law or agreed to in writing, software
  24. # distributed under the License is distributed on an "AS IS" BASIS,
  25. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  26. # See the License for the specific language governing permissions and
  27. # limitations under the License.
  28. import math
  29. from collections.abc import Sequence
  30. from enum import Enum, unique
  31. from typing import Union
  32. # Tercom-inspired limits
  33. _BEAM_WIDTH = 25
  34. # Sacrebleu-inspired limits
  35. _MAX_CACHE_SIZE = 10000
  36. _INT_INFINITY = int(1e16)
@unique
class _EditOperations(str, Enum):
    """Enumeration of the Levenshtein edit operations.

    Members mix in ``str``, so each member compares equal to its plain string value. ``@unique`` rejects
    two members sharing the same value.
    """

    OP_INSERT = "insert"
    OP_DELETE = "delete"
    OP_SUBSTITUTE = "substitute"
    # Used when the compared prediction and reference tokens are equal (zero-cost step).
    OP_NOTHING = "nothing"
    # Placeholder operation stored in matrix cells that have not been computed.
    OP_UNDEFINED = "undefined"
  45. class _LevenshteinEditDistance:
  46. """A convenience class for calculating the Levenshtein edit distance.
  47. Class will cache some intermediate values to hasten the calculation. The implementation follows the implementation
  48. from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/lib_ter.py,
  49. where the most of this implementation is adapted and copied from.
  50. Args:
  51. reference_tokens: list of reference tokens
  52. op_insert: cost of insertion operation
  53. op_delete: cost of deletion operation
  54. op_substitute: cost of substitution operation
  55. """
  56. def __init__(
  57. self, reference_tokens: list[str], op_insert: int = 1, op_delete: int = 1, op_substitute: int = 1
  58. ) -> None:
  59. self.reference_tokens = reference_tokens
  60. self.reference_len = len(reference_tokens)
  61. self.cache: dict[str, tuple[int, str]] = {}
  62. self.cache_size = 0
  63. self.op_insert = op_insert
  64. self.op_delete = op_delete
  65. self.op_substitute = op_substitute
  66. self.op_nothing = 0
  67. self.op_undefined = _INT_INFINITY
  68. def __call__(self, prediction_tokens: list[str]) -> tuple[int, tuple[_EditOperations, ...]]:
  69. """Calculate edit distance between self._words_ref and the hypothesis. Uses cache to skip some computations.
  70. Args:
  71. prediction_tokens: A tokenized predicted sentence.
  72. Return:
  73. A tuple of a calculated edit distance and a trace of executed operations.
  74. """
  75. # Use cached edit distance for already computed words
  76. start_position, cached_edit_distance = self._find_cache(prediction_tokens)
  77. # Calculate the rest of the edit distance matrix
  78. edit_distance_int, edit_distance, trace = self._levenshtein_edit_distance(
  79. prediction_tokens, start_position, cached_edit_distance
  80. )
  81. # Update our cache with the newly calculated rows
  82. self._add_cache(prediction_tokens, edit_distance)
  83. return edit_distance_int, trace
  84. def _levenshtein_edit_distance(
  85. self,
  86. prediction_tokens: list[str],
  87. prediction_start: int,
  88. cache: list[list[tuple[int, _EditOperations]]],
  89. ) -> tuple[int, list[list[tuple[int, _EditOperations]]], tuple[_EditOperations, ...]]:
  90. """Dynamic programming algorithm to compute the Levenhstein edit distance.
  91. Args:
  92. prediction_tokens: A tokenized predicted sentence.
  93. prediction_start: An index where a predicted sentence to be considered from.
  94. cache: A cached Levenshtein edit distance.
  95. Returns:
  96. Edit distance between the predicted sentence and the reference sentence
  97. """
  98. prediction_len = len(prediction_tokens)
  99. empty_rows: list[list[tuple[int, _EditOperations]]] = [
  100. list(self._get_empty_row(self.reference_len)) for _ in range(prediction_len - prediction_start)
  101. ]
  102. edit_distance: list[list[tuple[int, _EditOperations]]] = cache + empty_rows
  103. length_ratio = self.reference_len / prediction_len if prediction_tokens else 1.0
  104. # Ensure to not end up with zero overlaip with previous role
  105. beam_width = math.ceil(length_ratio / 2 + _BEAM_WIDTH) if length_ratio / 2 > _BEAM_WIDTH else _BEAM_WIDTH
  106. # Calculate the Levenshtein distance
  107. for i in range(prediction_start + 1, prediction_len + 1):
  108. pseudo_diag = math.floor(i * length_ratio)
  109. min_j = max(0, pseudo_diag - beam_width)
  110. max_j = (
  111. self.reference_len + 1 if i == prediction_len else min(self.reference_len + 1, pseudo_diag + beam_width)
  112. )
  113. for j in range(min_j, max_j):
  114. if j == 0:
  115. edit_distance[i][j] = (
  116. edit_distance[i - 1][j][0] + self.op_delete,
  117. _EditOperations.OP_DELETE,
  118. )
  119. else:
  120. if prediction_tokens[i - 1] == self.reference_tokens[j - 1]:
  121. cost_substitute = self.op_nothing
  122. operation_substitute = _EditOperations.OP_NOTHING
  123. else:
  124. cost_substitute = self.op_substitute
  125. operation_substitute = _EditOperations.OP_SUBSTITUTE
  126. # Tercom prefers no-op/sub, then insertion, then deletion. But since we flip the trace and compute
  127. # the alignment from the inverse, we need to swap order of insertion and deletion in the
  128. # preference.
  129. # Copied from: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/ter.py.
  130. operations = (
  131. (edit_distance[i - 1][j - 1][0] + cost_substitute, operation_substitute),
  132. (edit_distance[i - 1][j][0] + self.op_delete, _EditOperations.OP_DELETE),
  133. (edit_distance[i][j - 1][0] + self.op_insert, _EditOperations.OP_INSERT),
  134. )
  135. for operation_cost, operation_name in operations:
  136. if edit_distance[i][j][0] > operation_cost:
  137. edit_distance[i][j] = operation_cost, operation_name
  138. trace = self._get_trace(prediction_len, edit_distance)
  139. return edit_distance[-1][-1][0], edit_distance[len(cache) :], trace
  140. def _get_trace(
  141. self, prediction_len: int, edit_distance: list[list[tuple[int, _EditOperations]]]
  142. ) -> tuple[_EditOperations, ...]:
  143. """Get a trace of executed operations from the edit distance matrix.
  144. Args:
  145. prediction_len: A length of a tokenized predicted sentence.
  146. edit_distance:
  147. A matrix of the Levenshtedin edit distance. The element part of the matrix is a tuple of an edit
  148. operation cost and an edit operation itself.
  149. Return:
  150. A trace of executed operations returned as a tuple of `_EDIT_OPERATIONS` enumerates.
  151. Raises:
  152. ValueError:
  153. If an unknown operation has been applied.
  154. """
  155. trace: tuple[_EditOperations, ...] = ()
  156. i = prediction_len
  157. j = self.reference_len
  158. while i > 0 or j > 0:
  159. operation = edit_distance[i][j][1]
  160. trace = (operation, *trace)
  161. if operation in (_EditOperations.OP_SUBSTITUTE, _EditOperations.OP_NOTHING):
  162. i -= 1
  163. j -= 1
  164. elif operation == _EditOperations.OP_INSERT:
  165. j -= 1
  166. elif operation == _EditOperations.OP_DELETE:
  167. i -= 1
  168. else:
  169. raise ValueError(f"Unknown operation {operation!r}")
  170. return trace
  171. def _add_cache(self, prediction_tokens: list[str], edit_distance: list[list[tuple[int, _EditOperations]]]) -> None:
  172. """Add newly computed rows to cache.
  173. Since edit distance is only calculated on the hypothesis suffix that was not in cache, the number of rows in
  174. `edit_distance` matrx may be shorter than hypothesis length. In that case we skip over these initial words.
  175. Args:
  176. prediction_tokens: A tokenized predicted sentence.
  177. edit_distance:
  178. A matrix of the Levenshtedin edit distance. The element part of the matrix is a tuple of an edit
  179. operation cost and an edit operation itself.
  180. """
  181. if self.cache_size >= _MAX_CACHE_SIZE:
  182. return
  183. node = self.cache
  184. # how many initial words to skip
  185. skip_num = len(prediction_tokens) - len(edit_distance)
  186. # Jump through the cache to the current position
  187. for i in range(skip_num):
  188. node = node[prediction_tokens[i]][0] # type: ignore
  189. # Update cache with newly computed rows
  190. for word, row in zip(prediction_tokens[skip_num:], edit_distance):
  191. if word not in node:
  192. node[word] = ({}, tuple(row)) # type: ignore
  193. self.cache_size += 1
  194. value = node[word]
  195. node = value[0] # type: ignore
  196. def _find_cache(self, prediction_tokens: list[str]) -> tuple[int, list[list[tuple[int, _EditOperations]]]]:
  197. """Find the already calculated rows of the Levenshtein edit distance metric.
  198. Args:
  199. prediction_tokens: A tokenized predicted sentence.
  200. Return:
  201. A tuple of a start hypothesis position and `edit_distance` matrix.
  202. prediction_start: An index where a predicted sentence to be considered from.
  203. edit_distance:
  204. A matrix of the cached Levenshtedin edit distance. The element part of the matrix is a tuple of an edit
  205. operation cost and an edit operation itself.
  206. """
  207. node = self.cache
  208. start_position = 0
  209. edit_distance: list[list[tuple[int, _EditOperations]]] = [self._get_initial_row(self.reference_len)]
  210. for word in prediction_tokens:
  211. if word in node:
  212. start_position += 1
  213. node, row = node[word] # type: ignore
  214. edit_distance.append(row) # type: ignore
  215. else:
  216. break
  217. return start_position, edit_distance
  218. def _get_empty_row(self, length: int) -> list[tuple[int, _EditOperations]]:
  219. """Precomputed empty matrix row for Levenhstein edit distance.
  220. Args:
  221. length: A length of a tokenized sentence.
  222. Return:
  223. A list of tuples containing infinite edit operation costs and yet undefined edit operations.
  224. """
  225. return [(int(self.op_undefined), _EditOperations.OP_UNDEFINED)] * (length + 1)
  226. def _get_initial_row(self, length: int) -> list[tuple[int, _EditOperations]]:
  227. """First row corresponds to insertion operations of the reference, so 1 edit operation per reference word.
  228. Args:
  229. length: A length of a tokenized sentence.
  230. Return:
  231. A list of tuples containing edit operation costs of insert and insert edit operations.
  232. """
  233. return [(i * self.op_insert, _EditOperations.OP_INSERT) for i in range(length + 1)]
  234. def _validate_inputs(
  235. ref_corpus: Union[Sequence[str], Sequence[Sequence[str]]],
  236. hypothesis_corpus: Union[str, Sequence[str]],
  237. ) -> tuple[Sequence[Sequence[str]], Sequence[str]]:
  238. """Check and update (if needed) the format of reference and hypothesis corpora for various text evaluation metrics.
  239. Args:
  240. ref_corpus: An iterable of iterables of reference corpus.
  241. hypothesis_corpus: An iterable of hypothesis corpus.
  242. Return:
  243. ref_corpus: An iterable of iterables of reference corpus.
  244. hypothesis_corpus: An iterable of hypothesis corpus.
  245. Raises:
  246. ValueError:
  247. If length of `ref_corpus` and `hypothesis_corpus` differs.
  248. """
  249. if isinstance(hypothesis_corpus, str):
  250. hypothesis_corpus = [hypothesis_corpus]
  251. # Ensure reference corpus is properly of a type Sequence[Sequence[str]]
  252. if all(isinstance(ref, str) for ref in ref_corpus):
  253. ref_corpus = [ref_corpus] if len(hypothesis_corpus) == 1 else [[ref] for ref in ref_corpus] # type: ignore
  254. if hypothesis_corpus and all(ref for ref in ref_corpus) and len(ref_corpus) != len(hypothesis_corpus):
  255. raise ValueError(f"Corpus has different size {len(ref_corpus)} != {len(hypothesis_corpus)}")
  256. return ref_corpus, hypothesis_corpus
  257. def _edit_distance(prediction_tokens: list[str], reference_tokens: list[str]) -> int:
  258. """Dynamic programming algorithm to compute the edit distance.
  259. Args:
  260. prediction_tokens: A tokenized predicted sentence
  261. reference_tokens: A tokenized reference sentence
  262. Returns:
  263. Edit distance between the predicted sentence and the reference sentence
  264. """
  265. dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)]
  266. for i in range(len(prediction_tokens) + 1):
  267. dp[i][0] = i
  268. for j in range(len(reference_tokens) + 1):
  269. dp[0][j] = j
  270. for i in range(1, len(prediction_tokens) + 1):
  271. for j in range(1, len(reference_tokens) + 1):
  272. if prediction_tokens[i - 1] == reference_tokens[j - 1]:
  273. dp[i][j] = dp[i - 1][j - 1]
  274. else:
  275. dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
  276. return dp[-1][-1]
  277. def _flip_trace(trace: tuple[_EditOperations, ...]) -> tuple[_EditOperations, ...]:
  278. """Flip the trace of edit operations.
  279. Instead of rewriting a->b, get a recipe for rewriting b->a. Simply flips insertions and deletions.
  280. Args:
  281. trace: A tuple of edit operations.
  282. Return:
  283. inverted_trace:
  284. A tuple of inverted edit operations.
  285. """
  286. _flip_operations: dict[_EditOperations, _EditOperations] = {
  287. _EditOperations.OP_INSERT: _EditOperations.OP_DELETE,
  288. _EditOperations.OP_DELETE: _EditOperations.OP_INSERT,
  289. }
  290. def _replace_operation_or_retain(
  291. operation: _EditOperations, _flip_operations: dict[_EditOperations, _EditOperations]
  292. ) -> _EditOperations:
  293. if operation in _flip_operations:
  294. return _flip_operations.get(operation) # type: ignore
  295. return operation
  296. return tuple(_replace_operation_or_retain(operation, _flip_operations) for operation in trace)
  297. def _trace_to_alignment(trace: tuple[_EditOperations, ...]) -> tuple[dict[int, int], list[int], list[int]]:
  298. """Transform trace of edit operations into an alignment of the sequences.
  299. Args:
  300. trace: A trace of edit operations as a tuple of `_EDIT_OPERATIONS` enumerates.
  301. Return:
  302. alignments: A dictionary mapping aligned positions between a reference and a hypothesis.
  303. reference_errors: A list of error positions in a reference.
  304. hypothesis_errors: A list of error positions in a hypothesis.
  305. Raises:
  306. ValueError:
  307. If an unknown operation is
  308. """
  309. reference_position = hypothesis_position = -1
  310. reference_errors: list[int] = []
  311. hypothesis_errors: list[int] = []
  312. alignments: dict[int, int] = {}
  313. # we are rewriting a into b
  314. for operation in trace:
  315. if operation == _EditOperations.OP_NOTHING:
  316. hypothesis_position += 1
  317. reference_position += 1
  318. alignments[reference_position] = hypothesis_position
  319. reference_errors.append(0)
  320. hypothesis_errors.append(0)
  321. elif operation == _EditOperations.OP_SUBSTITUTE:
  322. hypothesis_position += 1
  323. reference_position += 1
  324. alignments[reference_position] = hypothesis_position
  325. reference_errors.append(1)
  326. hypothesis_errors.append(1)
  327. elif operation == _EditOperations.OP_INSERT:
  328. hypothesis_position += 1
  329. hypothesis_errors.append(1)
  330. elif operation == _EditOperations.OP_DELETE:
  331. reference_position += 1
  332. alignments[reference_position] = hypothesis_position
  333. reference_errors.append(1)
  334. else:
  335. raise ValueError(f"Unknown operation {operation!r}.")
  336. return alignments, reference_errors, hypothesis_errors