tokenization_nougat.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. # Copyright 2023 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Tokenizer class for Nougat.
  16. """
  17. import re
  18. from functools import partial
  19. from multiprocessing import Pool
  20. import numpy as np
  21. from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
  22. from tokenizers.models import BPE
  23. from ...tokenization_utils_tokenizers import TokenizersBackend
  24. from ...utils import is_levenshtein_available, is_nltk_available, logging, requires_backends
  25. if is_levenshtein_available():
  26. from Levenshtein import ratio
  27. if is_nltk_available():
  28. import nltk
  29. logger = logging.get_logger(__name__)
  30. VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
  31. def markdown_compatible(text: str) -> str:
  32. """
  33. Make text compatible with Markdown formatting.
  34. This function makes various text formatting adjustments to make it compatible with Markdown.
  35. Args:
  36. text (`str`):
  37. The input text to be made Markdown-compatible.
  38. Returns:
  39. `str`: The Markdown-compatible text.
  40. """
  41. # equation tag
  42. # Replace lines that start with a pattern like (decimal) \[some text\] with \[[some text] \tag{decimal}\].
  43. text = re.sub(r"^\(([\d.]+[a-zA-Z]?)\) \\\[(.+?)\\\]$", r"\[\2 \\tag{\1}\]", text, flags=re.MULTILINE)
  44. # Replace lines that start with a pattern like \[some text\] (decimal) with \[[some text] \tag{decimal}\].
  45. text = re.sub(r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\)$", r"\[\1 \\tag{\2}\]", text, flags=re.MULTILINE)
  46. # Replace lines that start with a pattern like \[some text\] (digits) \[another text\] with \[[some text] \tag{digits}\] [another text].
  47. text = re.sub(
  48. r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\) (\\\[.+?\\\])$",
  49. r"\[\1 \\tag{\2}\] \3",
  50. text,
  51. flags=re.MULTILINE,
  52. )
  53. # multi line
  54. text = text.replace(r"\. ", ". ")
  55. # bold formatting
  56. text = text.replace(r"\bm{", r"\mathbf{").replace(r"{\\bm ", r"\mathbf{")
  57. text = re.sub(r"\\mbox{ ?\\boldmath\$(.*?)\$}", r"\\mathbf{\1}", text)
  58. # Reformat urls (http, ftp and https only) to markdown [url](url) clickable format
  59. text = re.sub(
  60. r"((?:http|ftp|https):\/\/(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:\/~+#-]*[\w@?^=%&\/~+#-]))",
  61. r"[\1](\1)",
  62. text,
  63. )
  64. # algorithms
  65. text = re.sub(r"```\s*(.+?)\s*```", r"```\n\1\n```", text, flags=re.DOTALL)
  66. return text
  67. def normalize_list_like_lines(generation):
  68. """
  69. Normalize lines in the given text that resemble list items. The function looks for lines that start optionally with
  70. '-' or '*', possibly followed by Roman numerals or digits indicating nesting levels. The function reformats such
  71. lines to make them more structured.
  72. Args:
  73. generation (str): The input text containing lines that need to be normalized.
  74. Returns:
  75. str: The input text with the list-like lines normalized.
  76. Note:
  77. The function uses regular expressions to identify and reformat the list-like lines. The patterns capture
  78. optional bullet points, nesting levels indicated by numerals, and the actual list item content. The
  79. normalization adjusts the bullet point style and nesting levels based on the captured patterns.
  80. """
  81. lines = generation.split("\n")
  82. output_lines = []
  83. for line_no, line in enumerate(lines):
  84. match = re.search(r". ([-*]) ", line)
  85. if not match or line[0] not in ("-", "*"):
  86. output_lines.append(line)
  87. continue # Doesn't fit the pattern we want, no changes
  88. delim = match.group(1) + " "
  89. splits = line.split(delim)[1:]
  90. replacement = ""
  91. delim1 = line[0] + " "
  92. for i, item in enumerate(splits):
  93. level = 0
  94. potential_numeral, _, rest = item.strip().partition(" ")
  95. if not rest:
  96. continue
  97. # Infer current nesting level based on detected numbering
  98. if re.match(r"^[\dixv]+((?:\.[\dixv])?)+$", potential_numeral, flags=re.IGNORECASE | re.MULTILINE):
  99. level = potential_numeral.count(".")
  100. replacement += (
  101. ("\n" if i > 0 else "") + ("\t" * level) + (delim if i > 0 or line_no == 0 else delim1) + item.strip()
  102. )
  103. if line_no == len(lines) - 1: # If this is the last line in the generation
  104. replacement += "\n" # Add an empty line to the end of the generation
  105. output_lines.append(replacement)
  106. return "\n".join(output_lines)
  107. def find_next_punctuation(text: str, start_idx=0):
  108. """
  109. Find the index of the next punctuation mark.
  110. Args:
  111. text (`str`):
  112. String to examine
  113. start_idx (`int`, *optional*)
  114. Index where to start
  115. """
  116. for i in range(start_idx, len(text)):
  117. if text[i] in [".", "?", "!", "\n"]:
  118. return i
  119. return None
  120. def truncate_repetitions(text: str, min_len: int = 30) -> str:
  121. """
  122. Attempt to truncate repeating segments in the input string.
  123. This function looks for the longest repeating substring at the end of the input string and truncates it to appear
  124. only once. To be considered for removal, repetitions need to be continuous.
  125. Args:
  126. text (`str`):
  127. The input raw prediction to be truncated.
  128. min_len (int):
  129. The minimum length of the repeating segment.
  130. Returns:
  131. `str`: The input string with repeated segments truncated.
  132. """
  133. text_lower = text.lower()
  134. text_length = len(text_lower)
  135. if text_length < 2 * min_len:
  136. return text
  137. # try to find a length at which the tail is repeating
  138. max_repetition_length = None
  139. for repetition_length in range(min_len, int(text_length / 2)):
  140. # check if there is a repetition at the end
  141. same = True
  142. for i in range(0, repetition_length):
  143. if text_lower[text_length - repetition_length - i - 1] != text_lower[text_length - i - 1]:
  144. same = False
  145. break
  146. if same:
  147. max_repetition_length = repetition_length
  148. if max_repetition_length is None:
  149. return text
  150. lcs = text_lower[-max_repetition_length:]
  151. # remove all but the last repetition
  152. substituted_text = text
  153. substituted_text_lower = text_lower
  154. while substituted_text_lower.endswith(lcs):
  155. substituted_text = substituted_text[:-max_repetition_length]
  156. substituted_text_lower = substituted_text_lower[:-max_repetition_length]
  157. # this is the tail with the repetitions
  158. repeating_tail = text_lower[len(substituted_text_lower) :]
  159. # add until next punctuation and make sure last sentence is not repeating
  160. substituted_text_lower_out = substituted_text_lower
  161. while True:
  162. sentence_end = find_next_punctuation(text_lower, len(substituted_text_lower_out))
  163. sentence_start = find_next_punctuation(text_lower[::-1], len(substituted_text_lower_out))
  164. if sentence_end and sentence_start:
  165. sentence = text_lower[sentence_start:sentence_end]
  166. substituted_text_lower_out = text_lower[: sentence_end + 1]
  167. if sentence in repeating_tail:
  168. break
  169. else:
  170. break
  171. text_out = text[: len(substituted_text_lower_out)]
  172. return text_out
  173. def remove_numbers(lines):
  174. def _clean(s):
  175. return re.sub(r"(?:[\d_]|\*\*)", "", s).strip()
  176. if isinstance(lines, str):
  177. return _clean(lines)
  178. out = []
  179. for l in lines:
  180. out.append(_clean(l))
  181. return out
  182. def get_slices(lines, clean_lines):
  183. """
  184. Get slices of text based on specific criteria within the lines.
  185. This function identifies and returns slices of text from the input lines based on certain conditions.
  186. These conditions were chosen by the Nougat authors:
  187. - The slice is less than 200 characters long.
  188. - The slice is more than 3 characters long.
  189. - The slice does not start with "[MISSING_PAGE".
  190. - The slice is either the same as the next slice or the ratio of the two in terms of Levenshtein distance is
  191. greater than 0.9.
  192. Args:
  193. lines (`list[str]`):
  194. The list of lines containing the text.
  195. clean_lines (`list[str]`):
  196. A cleaned version of the text (without numbers).
  197. Returns:
  198. `list[tuple]`: A list of tuples representing the start and end indices of text slices.
  199. """
  200. indices = np.zeros(len(lines))
  201. for i in range(len(lines) - 1):
  202. j = i + 1
  203. while not clean_lines[j] and j < len(lines) - 1:
  204. j += 1
  205. if (
  206. len(clean_lines[i]) < 200
  207. and len(clean_lines[i]) > 3
  208. and len(clean_lines[j]) < 200
  209. and len(clean_lines[j]) > 3
  210. and not clean_lines[i].startswith("[MISSING_PAGE")
  211. and (clean_lines[i] == clean_lines[j] or ratio(clean_lines[i], clean_lines[j]) > 0.9)
  212. ):
  213. indices[i:j] = 1
  214. ids = np.where(indices)[0]
  215. slices = []
  216. if len(ids) == 0:
  217. return slices
  218. j0 = 0
  219. for j, x in enumerate(np.diff(ids) > 3):
  220. if x:
  221. slices.append((ids[j0], ids[j] + 2))
  222. j0 = j + 1
  223. slices.append((ids[j0], ids[-1] + 2))
  224. return [sli for sli in slices if sli[1] - sli[0] > 15]
def remove_slice_from_lines(lines, clean_text, slice) -> str:
    """
    Compute the block of text that should be deleted for one detected repetition slice.

    Despite its name, this function does not modify `lines`. It returns the text span, and the caller replaces
    that span in the full text (in this file, `NougatTokenizer.remove_hallucinated_references` substitutes a
    `[MISSING_PAGE_POST]` marker).

    Starting from `slice`, the span is widened:
    - backwards over up to 4 preceding lines. The scan stops at a `## References` heading, or at the first
      non-empty line whose cleaned text is no longer similar (Levenshtein `ratio` < 0.9) to the slice's first
      cleaned line;
    - forwards over up to 5 lines starting at `slice[1]`, stopping at the first dissimilar line.

    The end of the span is then trimmed at the character where the last line of the span starts to differ from
    the line before it (numeric characters are skipped while comparing). When a reference-like line (`* [`) was
    detected just before the slice, the span is also cut to start at the first `* [`.

    Args:
        lines (`list[str]`): The original lines of the text.
        clean_text (`list[str]`): A cleaned version of `lines` (without numbers), index-aligned with `lines`.
        slice (`tuple[int, int]`): Start and end line indices of the slice, as returned by `get_slices`.

    Returns:
        `str`: The whitespace-stripped text that should be removed. This can be empty.
    """
    base = clean_text[slice[0]]
    section = list(slice)
    check_start_flag = False
    # backwards pass, at most 4 lines (slice[0] - 1 down to slice[0] - 4).
    # NOTE(review): the stop bound is exclusive and clamped at 0, so line index 0 is never examined.
    for line_idx in range(max(0, slice[0] - 1), max(0, slice[0] - 5), -1):
        if not lines[line_idx]:
            continue
        if lines[line_idx] == "## References":
            # Extend the deletion up to and including the references heading.
            section[0] = line_idx
            break
        elif ratio(base, remove_numbers(lines[line_idx])) < 0.9:
            # First dissimilar line: the slice starts right after it...
            section[0] = line_idx + 1
            # ...unless the line before it looks like a long, dissimilar reference entry ("* [..."), in which
            # case this line is included and the result is later cut to start at "* [".
            potential_ref = remove_numbers(lines[max(0, line_idx - 1)].partition("* [")[-1])
            if len(potential_ref) >= 0.75 * len(base) and ratio(base, potential_ref) < 0.9:
                section[0] = line_idx
                check_start_flag = True
            break
    # forward pass, at most 5 lines
    for line_idx in range(min(len(lines), slice[1]), min(len(lines), slice[1] + 5)):
        if ratio(base, remove_numbers(lines[line_idx])) < 0.9:
            section[1] = line_idx
            break
    # Clamp the end index to the last existing line.
    if len(lines) <= section[1]:
        section[1] = len(lines) - 1
    to_delete = "\n".join(lines[section[0] : section[1] + 1])
    # cut off next page content:
    # walk the last two lines of the section in parallel (skipping numeric characters) until they differ;
    # `ib` ends up as the index in the last line where that happens.
    itera, iterb = enumerate(lines[section[1] - 1]), enumerate(lines[section[1]])
    while True:
        try:
            (ia, a) = next(itera)
            while a.isnumeric():
                (ia, a) = next(itera)
            (ib, b) = next(iterb)
            while b.isnumeric():
                (ib, b) = next(iterb)
            if a != b:
                break
        except StopIteration:
            break
    if check_start_flag and "* [" in to_delete:
        to_delete = "* [" + to_delete.partition("* [")[-1]
    try:
        # Keep the part of the last line after the divergence point out of the deletion.
        delta = len(lines[section[1]]) - ib - 1
        if delta > 0:
            to_delete = to_delete[:-delta]
    except UnboundLocalError:
        # `ib` is never bound when the loop above ended before consuming any character of the last line.
        pass
    return to_delete.strip()
  284. class NougatTokenizer(TokenizersBackend):
  285. """
  286. Tokenizer for Nougat (backed by HuggingFace tokenizers library).
  287. This tokenizer inherits from [`TokenizersBackend`] which contains most of the main methods. Users should
  288. refer to this superclass for more information regarding those methods. This class mainly adds Nougat-specific
  289. methods for postprocessing the generated text.
  290. Args:
  291. vocab_file (`str`, *optional*):
  292. Path to the vocabulary file.
  293. merges_file (`str`, *optional*):
  294. Path to the merges file.
  295. tokenizer_file (`str`, *optional*):
  296. [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
  297. contains everything needed to load the tokenizer.
  298. clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
  299. Whether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
  300. spaces.
  301. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  302. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  303. token instead.
  304. bos_token (`str`, *optional*, defaults to `"<s>"`):
  305. The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
  306. eos_token (`str`, *optional*, defaults to `"</s>"`):
  307. The end of sequence token.
  308. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  309. The token used for padding, for example when batching sequences of different lengths.
  310. vocab (`str`, `dict` or `list`, *optional*):
  311. Custom vocabulary dictionary. If not provided, vocabulary is loaded from vocab_file.
  312. merges (`str` or `list`, *optional*):
  313. Custom merges list. If not provided, merges are loaded from merges_file.
  314. """
  315. vocab_files_names = VOCAB_FILES_NAMES
  316. model_input_names = ["input_ids", "attention_mask"]
  317. model = BPE
  318. def __init__(
  319. self,
  320. errors: str = "replace",
  321. unk_token: str = "<unk>",
  322. bos_token: str = "<s>",
  323. eos_token: str = "</s>",
  324. pad_token: str = "<pad>",
  325. vocab: str | dict | list | None = None,
  326. merges: str | list | None = None,
  327. **kwargs,
  328. ):
  329. self._vocab = (
  330. vocab
  331. if vocab is not None
  332. else {
  333. str(bos_token): 0,
  334. str(pad_token): 1,
  335. str(eos_token): 2,
  336. str(unk_token): 3,
  337. "[START_REF]": 4,
  338. }
  339. )
  340. self._merges = merges or []
  341. self._tokenizer = Tokenizer(
  342. BPE(
  343. vocab=self._vocab,
  344. merges=self._merges,
  345. dropout=None,
  346. continuing_subword_prefix="",
  347. end_of_word_suffix="",
  348. fuse_unk=False,
  349. )
  350. )
  351. self._tokenizer.normalizer = normalizers.NFKC()
  352. self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  353. [
  354. pre_tokenizers.Split(pattern="SPL1T-TH1S-Pl3A5E", behavior="removed", invert=False),
  355. pre_tokenizers.Digits(individual_digits=True),
  356. pre_tokenizers.Split(
  357. pattern=r"[\(\)\[\]\{\}]|([!\"#\$%\&'\*\+,\-\./:;<=>\?\\\^_`\|\~])\1*",
  358. behavior="isolated",
  359. invert=False,
  360. ),
  361. pre_tokenizers.Split(pattern="\n", behavior="isolated", invert=False),
  362. pre_tokenizers.ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True),
  363. ]
  364. )
  365. self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True)
  366. super().__init__(
  367. errors=errors,
  368. unk_token=unk_token,
  369. bos_token=bos_token,
  370. eos_token=eos_token,
  371. pad_token=pad_token,
  372. **kwargs,
  373. )
  374. self._tokenizer.post_processor = processors.TemplateProcessing(
  375. single=f"{bos_token}:0 $A:0 {eos_token}:0",
  376. pair="$A:0 $B:1",
  377. special_tokens=[
  378. (str(eos_token), self.eos_token_id),
  379. (str(bos_token), self.bos_token_id),
  380. ],
  381. )
  382. # Enable truncation and padding
  383. self._tokenizer.enable_truncation(max_length=4096)
  384. self._tokenizer.enable_padding(length=4096, pad_id=self.pad_token_id, pad_token=str(pad_token))
  385. def remove_hallucinated_references(self, text: str) -> str:
  386. """
  387. Remove hallucinated or missing references from the text.
  388. This function identifies and removes references that are marked as missing or hallucinated from the input text.
  389. Args:
  390. text (`str`):
  391. The input text containing references.
  392. Returns:
  393. `str`: The text with hallucinated references removed.
  394. """
  395. lines = text.split("\n")
  396. if len(lines) == 0:
  397. return ""
  398. clean_lines = remove_numbers(lines)
  399. slices = get_slices(lines, clean_lines)
  400. to_delete = []
  401. for slice in slices:
  402. to_delete.append(remove_slice_from_lines(lines, clean_lines, slice))
  403. for to_delete in reversed(to_delete):
  404. text = text.replace(to_delete, "\n\n[MISSING_PAGE_POST]\n\n")
  405. text = re.sub(
  406. r"## References\n+\[MISSING_PAGE_POST(:\d+)?\]",
  407. "\n\n[MISSING_PAGE_POST\\1]",
  408. text,
  409. )
  410. return text
  411. def correct_tables(self, generation: str) -> str:
  412. """
  413. Takes a generated string and fixes tables/tabulars to make them match the markdown format needed.
  414. Args:
  415. generation (str): The generated text to be postprocessed.
  416. Returns:
  417. str: The postprocessed text.
  418. Example:
  419. ```python
  420. correct_tables("\\begin{table} \\begin{tabular}{l l} & \\ \\end{tabular} \\end{table}")
  421. "\\begin{table}\n\\begin{tabular}{l l} & \\ \\end{tabular}\n\\end{table}"
  422. ```
  423. """
  424. # remove obvious wrong tables
  425. for l in generation.split("\n"):
  426. if l.count("\\begin{tabular}") > 15 or l.count("\\multicolumn") > 60 or l.count("&") > 400:
  427. generation = generation.replace(l, "")
  428. # whitespace corrections
  429. generation = generation.replace("\\begin{table} \\begin{tabular}", "\\begin{table}\n\\begin{tabular}")
  430. generation = generation.replace("\\end{tabular} \\end{table}", "\\end{tabular}\n\\end{table}")
  431. generation = generation.replace("\\end{table} Tab", "\\end{table}\nTab")
  432. generation = re.sub(r"(^.+)\\begin{tab", r"\1\n\\begin{tab", generation, flags=re.MULTILINE)
  433. # Remove left-aligned empty LaTeX tabular blocks.
  434. generation = generation.replace(r"\begin{tabular}{l l} & \\ \end{tabular}", "")
  435. # Remove tabulars with just 2 newline characters.
  436. generation = generation.replace("\\begin{tabular}{}\n\n\\end{tabular}", "")
  437. return generation
  438. def post_process_single(self, generation: str, fix_markdown: bool = True) -> str:
  439. """
  440. Postprocess a single generated text. Regular expressions used here are taken directly from the Nougat article
  441. authors. These expressions are commented for clarity and tested end-to-end in most cases.
  442. Args:
  443. generation (str): The generated text to be postprocessed.
  444. fix_markdown (bool, optional): Whether to perform Markdown formatting fixes. Default is True.
  445. Returns:
  446. str: The postprocessed text.
  447. """
  448. generation = re.sub(
  449. r"(?:\n|^)#+ \d*\W? ?(.{100,})", r"\n\1", generation
  450. ) # too long section titles probably are none
  451. generation = generation.strip()
  452. # Remove LaTeX left margin tag
  453. generation = generation.replace("\n* [leftmargin=*]\n", "\n")
  454. # Remove lines with markdown headings starting with #, with numerals,
  455. # and possibly roman numerals with trailing spaces and newlines
  456. generation = re.sub(r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)", "", generation, flags=re.MULTILINE)
  457. # most likely hallucinated titles
  458. lines = generation.split("\n")
  459. if lines[-1].startswith("#") and lines[-1].lstrip("#").startswith(" ") and len(lines) > 1:
  460. logger.info("Likely hallucinated title at the end of the page: " + lines[-1])
  461. generation = "\n".join(lines[:-1])
  462. # obvious repetition detection
  463. generation = truncate_repetitions(generation)
  464. # Reference corrections
  465. generation = self.remove_hallucinated_references(generation)
  466. # Remove lines starting with asterisks and numbers like "*[1]" and followed by capital letters and periods (ie too long references)
  467. generation = re.sub(r"^\* \[\d+\](\s?[A-W]\.+\s?){10,}.*$", "", generation, flags=re.MULTILINE)
  468. # Remove empty brackets after a reference number in brackets. *[12][]ABC will become *[12]ABC
  469. generation = re.sub(r"^(\* \[\d+\])\[\](.*)$", r"\1\2", generation, flags=re.MULTILINE)
  470. # Remove single characters before or after 2 new lines
  471. generation = re.sub(r"(^\w\n\n|\n\n\w$)", "", generation)
  472. # pmc math artifact correction
  473. generation = re.sub(
  474. r"([\s.,()])_([a-zA-Z0-9])__([a-zA-Z0-9]){1,3}_([\s.,:()])",
  475. r"\1\(\2_{\3}\)\4",
  476. generation,
  477. )
  478. generation = re.sub(r"([\s.,\d])_([a-zA-Z0-9])_([\s.,\d;])", r"\1\(\2\)\3", generation)
  479. # footnote mistakes
  480. generation = re.sub(
  481. r"(\nFootnote .*?:) (?:footnotetext|thanks):\W*(.*(?:\n\n|$))",
  482. r"\1 \2",
  483. generation,
  484. )
  485. # TODO Come up with footnote formatting inside a table
  486. generation = re.sub(r"\[FOOTNOTE:.+?\](.*?)\[ENDFOOTNOTE\]", "", generation)
  487. # itemize post processing
  488. generation = normalize_list_like_lines(generation)
  489. if generation.endswith((".", "}")):
  490. generation += "\n\n"
  491. if re.match(r"[A-Z0-9,;:]$", generation):
  492. # add space in case it there is a comma or word ending
  493. generation += " "
  494. elif generation.startswith(("#", "**", "\\begin")):
  495. generation = "\n\n" + generation
  496. elif generation.split("\n")[-1].startswith(("#", "Figure", "Table")):
  497. generation = generation + "\n\n"
  498. else:
  499. try:
  500. last_word = generation.split(" ")[-1]
  501. if last_word in nltk.corpus.words.words():
  502. generation += " "
  503. except LookupError:
  504. # add space just in case. Will split words but better than concatenating them
  505. generation += " "
  506. # table corrections
  507. generation = self.correct_tables(generation)
  508. # Remove optional, empty square brackets after begin{array}
  509. generation = generation.replace("\\begin{array}[]{", "\\begin{array}{")
  510. # Remove empty or malformed LaTeX tabular blocks with 2 or more columns specified, with spaces and ampersands.
  511. generation = re.sub(
  512. r"\\begin{tabular}{([clr ]){2,}}\s*[& ]*\s*(\\\\)? \\end{tabular}",
  513. "",
  514. generation,
  515. )
  516. # Remove lines containing "S.A.B." one or more times. Was included in Nougat's code.
  517. generation = re.sub(r"(\*\*S\. A\. B\.\*\*\n+){2,}", "", generation)
  518. # Remove markdown-style headers that are incomplete or empty on multiple lines.
  519. generation = re.sub(r"^#+( [\[\d\w])?$", "", generation, flags=re.MULTILINE)
  520. # Remove lines with just one period.
  521. generation = re.sub(r"^\.\s*$", "", generation, flags=re.MULTILINE)
  522. # Replace instances of three or more newlines with just two newlines.
  523. generation = re.sub(r"\n{3,}", "\n\n", generation)
  524. if fix_markdown:
  525. return markdown_compatible(generation)
  526. else:
  527. return generation
  528. def post_process_generation(
  529. self,
  530. generation: str | list[str],
  531. fix_markdown: bool = True,
  532. num_workers: int | None = None,
  533. ) -> str | list[str]:
  534. """
  535. Postprocess a generated text or a list of generated texts.
  536. This function can be used to perform postprocessing on generated text, such as fixing Markdown formatting.
  537. Postprocessing is quite slow so it is recommended to use multiprocessing to speed up the process.
  538. Args:
  539. generation (Union[str, list[str]]):
  540. The generated text or a list of generated texts.
  541. fix_markdown (`bool`, *optional*, defaults to `True`):
  542. Whether to perform Markdown formatting fixes.
  543. num_workers (`int`, *optional*):
  544. Optional number of workers to pass to leverage multiprocessing (postprocessing several texts in
  545. parallel).
  546. Returns:
  547. Union[str, list[str]]: The postprocessed text or list of postprocessed texts.
  548. """
  549. requires_backends(self, ["nltk", "levenshtein"])
  550. if isinstance(generation, list):
  551. if num_workers is not None and isinstance(num_workers, int):
  552. with Pool(num_workers) as p:
  553. return p.map(partial(self.post_process_single, fix_markdown=fix_markdown), generation)
  554. else:
  555. return [self.post_process_single(s, fix_markdown=fix_markdown) for s in generation]
  556. else:
  557. return self.post_process_single(generation, fix_markdown=fix_markdown)
# Public API of this module: only the tokenizer class is exported; the module-level post-processing helpers
# defined in this file are not re-exported via `from ... import *`.
__all__ = ["NougatTokenizer"]