| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660 |
- # Copyright 2023 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Tokenizer class for Nougat.
- """
import re
from functools import partial
from multiprocessing import Pool

import numpy as np
from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE

from ...tokenization_utils_tokenizers import TokenizersBackend
from ...utils import is_levenshtein_available, is_nltk_available, logging, requires_backends
- if is_levenshtein_available():
- from Levenshtein import ratio
- if is_nltk_available():
- import nltk
logger = logging.get_logger(__name__)

# Names of the files a saved Nougat tokenizer is made of, keyed by the argument name used to pass each one.
VOCAB_FILES_NAMES = dict(
    vocab_file="vocab.json",
    merges_file="merges.txt",
    tokenizer_file="tokenizer.json",
)
def markdown_compatible(text: str) -> str:
    """
    Rewrite generated text so that it renders correctly as Markdown.

    The following adjustments are applied in this order: equation numbers are turned into LaTeX tags, escaped
    periods are unescaped, bold-math macros are normalized to mathbf, bare http/https/ftp URLs become clickable
    links, and the content of triple-backtick blocks is moved onto its own lines.

    Args:
        text (`str`):
            The text to adjust.

    Returns:
        `str`: The adjusted text.
    """
    # Equation tags, one line at a time:
    #   "(1) \[x\]"          -> "\[x \tag{1}\]"
    #   "\[x\] (1)"          -> "\[x \tag{1}\]"
    #   "\[x\] (1) \[y\]"    -> "\[x \tag{1}\] \[y\]"
    equation_tag_rules = (
        (r"^\(([\d.]+[a-zA-Z]?)\) \\\[(.+?)\\\]$", r"\[\2 \\tag{\1}\]"),
        (r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\)$", r"\[\1 \\tag{\2}\]"),
        (r"^\\\[(.+?)\\\] \(([\d.]+[a-zA-Z]?)\) (\\\[.+?\\\])$", r"\[\1 \\tag{\2}\] \3"),
    )
    for pattern, template in equation_tag_rules:
        text = re.sub(pattern, template, text, flags=re.MULTILINE)

    # Escaped periods ("\. ") that appear in multi-line content.
    text = text.replace(r"\. ", ". ")

    # Bold math: the literal prefixes "\bm{" and "{\\bm " become "\mathbf{"; "\mbox{\boldmath$...$}" likewise.
    for bold_prefix in (r"\bm{", r"{\\bm "):
        text = text.replace(bold_prefix, r"\mathbf{")
    text = re.sub(r"\\mbox{ ?\\boldmath\$(.*?)\$}", r"\\mathbf{\1}", text)

    # Bare http/https/ftp URLs become "[url](url)" links.
    url_pattern = r"((?:http|ftp|https):\/\/(?:[\w_-]+(?:(?:\.[\w_-]+)+))(?:[\w.,@?^=%&:\/~+#-]*[\w@?^=%&\/~+#-]))"
    text = re.sub(url_pattern, r"[\1](\1)", text)

    # Algorithms / code: put the content of each ``` ... ``` span on its own lines.
    return re.sub(r"```\s*(.+?)\s*```", r"```\n\1\n```", text, flags=re.DOTALL)
def normalize_list_like_lines(generation):
    """
    Break bullet lines that contain inline sub-items onto separate, indented lines.

    A line is only rewritten when it starts with "-" or "*" and also contains an inline separator, i.e. any
    character, a space, "-" or "*", and a space. The separator character found first is used as the delimiter.
    The line is split on every occurrence of "<delimiter> ", and the first piece is dropped; that piece is normally
    the empty text before the leading bullet. Every remaining piece with at least two words is emitted on its own
    line. It is prefixed with a bullet and indented with one tab per "." in its leading numeral (e.g. "1.1",
    "ii.i"). Pieces consisting of a single word are dropped. A rewritten line that is the last line of the
    generation gets a trailing newline. All other lines are kept unchanged.

    Args:
        generation (str): The text to normalize.

    Returns:
        str: The text with list-like lines normalized.
    """
    all_lines = generation.split("\n")
    last_index = len(all_lines) - 1
    normalized = []

    for idx, current in enumerate(all_lines):
        found = re.search(r". ([-*]) ", current)
        if found is None or current[0] not in ("-", "*"):
            # Not a bullet line with inline sub-items: keep it untouched.
            normalized.append(current)
            continue

        inner_delim = found.group(1) + " "
        lead_delim = current[0] + " "
        pieces = []
        for pos, chunk in enumerate(current.split(inner_delim)[1:]):
            stripped = chunk.strip()
            numeral, _, remainder = stripped.partition(" ")
            if not remainder:
                continue
            # Nesting depth is inferred from the dots in a leading arabic/roman numeral.
            depth = 0
            if re.match(r"^[\dixv]+((?:\.[\dixv])?)+$", numeral, flags=re.IGNORECASE | re.MULTILINE):
                depth = numeral.count(".")
            separator = "\n" if pos > 0 else ""
            bullet = inner_delim if (pos > 0 or idx == 0) else lead_delim
            pieces.append(separator + "\t" * depth + bullet + stripped)

        rebuilt = "".join(pieces)
        if idx == last_index:
            # Keep an empty line after a rewritten last line of the generation.
            rebuilt += "\n"
        normalized.append(rebuilt)

    return "\n".join(normalized)
def find_next_punctuation(text: str, start_idx=0):
    """
    Return the index of the first sentence terminator at or after `start_idx`.

    The terminators are ".", "?", "!" and the newline character.

    Args:
        text (`str`):
            String to scan.
        start_idx (`int`, *optional*, defaults to 0):
            Index at which scanning starts.

    Returns:
        `int` or `None`: Index of the terminator, or `None` if there is none.
    """
    terminators = (".", "?", "!", "\n")
    return next((idx for idx in range(start_idx, len(text)) if text[idx] in terminators), None)
def truncate_repetitions(text: str, min_len: int = 30) -> str:
    """
    Cut off a repeating tail of `text`.

    The search runs over all periods `p` with `min_len <= p < len(text) // 2` and keeps the largest one for which
    the last `p` characters equal the `p` characters right before them. The comparison is case-insensitive. All
    trailing copies of that unit are stripped. The kept prefix is then re-extended up to subsequent sentence
    boundaries found with `find_next_punctuation`. This stops as soon as a candidate sentence occurs inside the
    stripped tail, or when no further boundary is found.

    Args:
        text (`str`):
            The raw prediction to truncate.
        min_len (`int`, *optional*, defaults to 30):
            The minimum length of the repeating unit.

    Returns:
        `str`: `text` itself when it is shorter than `2 * min_len` or no repetition is found, otherwise the
        truncated prefix of `text`.
    """
    lowered = text.lower()
    n = len(lowered)
    if n < 2 * min_len:
        return text

    # Largest period whose last occurrence is immediately preceded by an identical copy.
    best_period = None
    for period in range(min_len, n // 2):
        if all(lowered[n - period - k - 1] == lowered[n - k - 1] for k in range(period)):
            best_period = period
    if best_period is None:
        return text

    # Strip every trailing copy of the repeating unit.
    unit = lowered[-best_period:]
    kept = lowered
    while kept.endswith(unit):
        kept = kept[:-best_period]
    repeating_tail = lowered[len(kept) :]

    # NOTE(review): `sentence_start` is searched in the reversed string but used as a forward index below; this
    # looks unintended but is kept for parity with the reference Nougat post-processing.
    reversed_lowered = lowered[::-1]
    cursor = len(kept)
    while True:
        sentence_end = find_next_punctuation(lowered, cursor)
        sentence_start = find_next_punctuation(reversed_lowered, cursor)
        if not (sentence_end and sentence_start):
            break
        cursor = sentence_end + 1
        if lowered[sentence_start:sentence_end] in repeating_tail:
            break

    # NOTE(review): `cursor` indexes the lowercased string; `str.lower()` can change the length of some Unicode
    # text, in which case this slice of `text` is slightly off.
    return text[:cursor]
def remove_numbers(lines):
    """
    Strip digits, underscores and "**" markers from text, then trim surrounding whitespace.

    Args:
        lines (`str` or iterable of `str`): A single string or a collection of strings.

    Returns:
        `str` when given a `str`, otherwise a `list[str]` with each element cleaned.
    """
    noise = r"(?:[\d_]|\*\*)"

    def _strip_noise(value):
        return re.sub(noise, "", value).strip()

    if isinstance(lines, str):
        return _strip_noise(lines)
    return [_strip_noise(item) for item in lines]
def get_slices(lines, clean_lines):
    """
    Locate runs of near-duplicate lines, which indicate hallucinated, repeated content.

    A line is flagged when it and the next non-empty cleaned line satisfy all of the following:
    - both are between 4 and 199 characters long;
    - the first does not start with "[MISSING_PAGE";
    - they are equal or have a Levenshtein `ratio` above 0.9.

    Flagged line indices less than four apart are merged into one run, and only runs spanning more than 15 lines
    are returned.

    Args:
        lines (`list[str]`):
            The original lines.
        clean_lines (`list[str]`):
            The same lines after `remove_numbers`.

    Returns:
        `list[tuple]`: `(start, end)` index pairs, where `end` is two past the last flagged index of the run.
    """
    n_lines = len(lines)
    flagged = np.zeros(n_lines)

    for start in range(n_lines - 1):
        nxt = start + 1
        # Skip forward over empty cleaned lines, but never past the last line.
        while not clean_lines[nxt] and nxt < n_lines - 1:
            nxt += 1
        current, following = clean_lines[start], clean_lines[nxt]
        if not (3 < len(current) < 200 and 3 < len(following) < 200):
            continue
        if current.startswith("[MISSING_PAGE"):
            continue
        if current == following or ratio(current, following) > 0.9:
            flagged[start:nxt] = 1

    hit_ids = np.where(flagged)[0]
    if len(hit_ids) == 0:
        return []

    runs = []
    run_start = 0
    for pos, is_gap in enumerate(np.diff(hit_ids) > 3):
        if is_gap:
            runs.append((hit_ids[run_start], hit_ids[pos] + 2))
            run_start = pos + 1
    runs.append((hit_ids[run_start], hit_ids[-1] + 2))

    return [run for run in runs if run[1] - run[0] > 15]
def remove_slice_from_lines(lines, clean_text, slice) -> str:
    """
    Grow a detected repetition slice to its real boundaries and return the text that should be removed.

    Starting from `slice`, the function scans up to four lines backwards and up to five lines forwards. It stops at
    the first line whose cleaned text has a Levenshtein `ratio` below 0.9 against the slice's first cleaned line,
    or, in the backwards scan, at a "## References" heading. The joined lines of the resulting section are then
    trimmed at both ends (see the inline comments) and returned. Nothing is mutated.

    Args:
        lines (list of str): The list of lines containing the text.
        clean_text (list of str): A cleaned version of the text (without numbers), index-aligned with `lines`.
        slice (tuple): A `(start, end)` pair of line indices, as produced by `get_slices`.

    Returns:
        str: The text of the section to remove, joined with newlines and stripped.
    """
    base = clean_text[slice[0]]
    section = list(slice)
    check_start_flag = False
    # backwards pass over at most 4 lines (slice[0] - 1 down to slice[0] - 4); index 0 is never visited
    for line_idx in range(max(0, slice[0] - 1), max(0, slice[0] - 5), -1):
        if not lines[line_idx]:
            continue
        if lines[line_idx] == "## References":
            # the repeated block starts right at the references heading
            section[0] = line_idx
            break
        elif ratio(base, remove_numbers(lines[line_idx])) < 0.9:
            # first dissimilar line: the section starts just after it ...
            section[0] = line_idx + 1
            # ... unless the line before it carries a "* [" reference of comparable length, in which case the
            # section is extended back by one line and later cut so that it starts at "* ["
            potential_ref = remove_numbers(lines[max(0, line_idx - 1)].partition("* [")[-1])
            if len(potential_ref) >= 0.75 * len(base) and ratio(base, potential_ref) < 0.9:
                section[0] = line_idx
                check_start_flag = True
            break
    # forward pass, at most 5 lines: the section ends at the first dissimilar line
    for line_idx in range(min(len(lines), slice[1]), min(len(lines), slice[1] + 5)):
        if ratio(base, remove_numbers(lines[line_idx])) < 0.9:
            section[1] = line_idx
            break
    # slice[1] may point past the end of `lines`; clamp to the last line
    if len(lines) <= section[1]:
        section[1] = len(lines) - 1
    to_delete = "\n".join(lines[section[0] : section[1] + 1])
    # cut off next page content: walk the last two lines of the section character by character (skipping numeric
    # characters) until they diverge; `ib` ends up at the divergence point in the last line
    itera, iterb = enumerate(lines[section[1] - 1]), enumerate(lines[section[1]])
    while True:
        try:
            (ia, a) = next(itera)
            while a.isnumeric():
                (ia, a) = next(itera)
            (ib, b) = next(iterb)
            while b.isnumeric():
                (ib, b) = next(iterb)
            if a != b:
                break
        except StopIteration:
            break
    if check_start_flag and "* [" in to_delete:
        # drop everything before the first reference marker
        to_delete = "* [" + to_delete.partition("* [")[-1]
    try:
        # remove the characters of the last line that follow the divergence point
        delta = len(lines[section[1]]) - ib - 1
        if delta > 0:
            to_delete = to_delete[:-delta]
    except UnboundLocalError:
        # `ib` is never bound when the last line yielded no characters in the loop above (e.g. it is empty);
        # nothing is trimmed in that case
        pass
    return to_delete.strip()
class NougatTokenizer(TokenizersBackend):
    """
    Tokenizer for Nougat (backed by HuggingFace tokenizers library).

    This tokenizer inherits from [`TokenizersBackend`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods. This class mainly adds Nougat-specific
    methods for postprocessing the generated text.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file (accepted through `**kwargs` and forwarded to the base class).
        merges_file (`str`, *optional*):
            Path to the merges file (accepted through `**kwargs` and forwarded to the base class).
        tokenizer_file (`str`, *optional*):
            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer (accepted through `**kwargs`).
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
            spaces (accepted through `**kwargs`).
        errors (`str`, *optional*, defaults to `"replace"`):
            Forwarded unchanged to [`TokenizersBackend`].
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        vocab (`str`, `dict` or `list`, *optional*):
            Custom vocabulary. When omitted, a minimal vocabulary made of the special tokens plus `"[START_REF]"` is
            built here.
        merges (`str` or `list`, *optional*):
            Custom merges list. When omitted (or empty), no merges are used.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    model = BPE

    # Cache of NLTK's English word list used by `post_process_single`. It is filled lazily (once per process) so the
    # corpus is not rebuilt on every call and membership is a set lookup instead of a linear list scan.
    _english_words: frozenset[str] | None = None

    def __init__(
        self,
        errors: str = "replace",
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: str = "<pad>",
        vocab: str | dict | list | None = None,
        merges: str | list | None = None,
        **kwargs,
    ):
        # Without an explicit vocabulary, fall back to one holding only the special tokens and "[START_REF]".
        self._vocab = (
            vocab
            if vocab is not None
            else {
                str(bos_token): 0,
                str(pad_token): 1,
                str(eos_token): 2,
                str(unk_token): 3,
                "[START_REF]": 4,
            }
        )
        self._merges = merges or []
        self._tokenizer = Tokenizer(
            BPE(
                vocab=self._vocab,
                merges=self._merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )
        self._tokenizer.normalizer = normalizers.NFKC()
        # `pre_tokenizers.Split` treats a plain `str` pattern as a literal string; only a `tokenizers.Regex` is
        # interpreted as a regular expression. The punctuation pattern (character class, alternation and a
        # backreference) is a regex and must therefore be wrapped; the other two `Split` patterns contain no regex
        # syntax and work as literals.
        self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(pattern="SPL1T-TH1S-Pl3A5E", behavior="removed", invert=False),
                pre_tokenizers.Digits(individual_digits=True),
                pre_tokenizers.Split(
                    pattern=Regex(r"[\(\)\[\]\{\}]|([!\"#\$%\&'\*\+,\-\./:;<=>\?\\\^_`\|\~])\1*"),
                    behavior="isolated",
                    invert=False,
                ),
                pre_tokenizers.Split(pattern="\n", behavior="isolated", invert=False),
                pre_tokenizers.ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True),
            ]
        )
        self._tokenizer.decoder = decoders.ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True)

        super().__init__(
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

        # Single sequences are wrapped as "<bos> A <eos>"; pairs are concatenated without special tokens.
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{bos_token}:0 $A:0 {eos_token}:0",
            pair="$A:0 $B:1",
            special_tokens=[
                (str(eos_token), self.eos_token_id),
                (str(bos_token), self.bos_token_id),
            ],
        )

        # Enable truncation and padding (every encoding is truncated / padded to exactly 4096 tokens)
        self._tokenizer.enable_truncation(max_length=4096)
        self._tokenizer.enable_padding(length=4096, pad_id=self.pad_token_id, pad_token=str(pad_token))

    @classmethod
    def _is_english_word(cls, word: str) -> bool:
        """
        Return whether `word` appears in NLTK's `words` corpus.

        The corpus is loaded on first use and cached on the class. A `LookupError` raised by NLTK while loading
        propagates to the caller and nothing is cached, so the next call tries again.
        """
        if cls._english_words is None:
            cls._english_words = frozenset(nltk.corpus.words.words())
        return word in cls._english_words

    def remove_hallucinated_references(self, text: str) -> str:
        """
        Remove hallucinated or missing references from the text.

        Runs of near-duplicate lines (see `get_slices`) are replaced by a `[MISSING_PAGE_POST]` marker, and a
        "## References" heading directly followed by such a marker is collapsed into the marker.

        Args:
            text (`str`):
                The input text containing references.

        Returns:
            `str`: The text with hallucinated references removed.
        """
        lines = text.split("\n")  # never empty: str.split always returns at least one element
        clean_lines = remove_numbers(lines)
        slices = get_slices(lines, clean_lines)
        removed_chunks = [remove_slice_from_lines(lines, clean_lines, sli) for sli in slices]
        for chunk in reversed(removed_chunks):
            # An empty chunk would make `str.replace` insert the marker between every character.
            if chunk:
                text = text.replace(chunk, "\n\n[MISSING_PAGE_POST]\n\n")
        text = re.sub(
            r"## References\n+\[MISSING_PAGE_POST(:\d+)?\]",
            "\n\n[MISSING_PAGE_POST\\1]",
            text,
        )
        return text

    def correct_tables(self, generation: str) -> str:
        """
        Takes a generated string and fixes tables/tabulars to make them match the markdown format needed.

        Args:
            generation (str): The generated text to be postprocessed.

        Returns:
            str: The postprocessed text.

        Example:

        ```python
        tokenizer.correct_tables("\\begin{table} \\begin{tabular}{l l} & \\ \\end{tabular} \\end{table}")
        "\\begin{table}\n\\begin{tabular}{l l} & \\ \\end{tabular}\n\\end{table}"
        ```
        """
        # remove obvious wrong tables (lines with implausibly many tabulars, multicolumns or cells)
        for line in generation.split("\n"):
            if line.count("\\begin{tabular}") > 15 or line.count("\\multicolumn") > 60 or line.count("&") > 400:
                generation = generation.replace(line, "")
        # whitespace corrections
        generation = generation.replace("\\begin{table} \\begin{tabular}", "\\begin{table}\n\\begin{tabular}")
        generation = generation.replace("\\end{tabular} \\end{table}", "\\end{tabular}\n\\end{table}")
        generation = generation.replace("\\end{table} Tab", "\\end{table}\nTab")

        generation = re.sub(r"(^.+)\\begin{tab", r"\1\n\\begin{tab", generation, flags=re.MULTILINE)

        # Remove left-aligned empty LaTeX tabular blocks.
        generation = generation.replace(r"\begin{tabular}{l l} & \\ \end{tabular}", "")
        # Remove tabulars with just 2 newline characters.
        generation = generation.replace("\\begin{tabular}{}\n\n\\end{tabular}", "")
        return generation

    def post_process_single(self, generation: str, fix_markdown: bool = True) -> str:
        """
        Postprocess a single generated text. Regular expressions used here are taken directly from the Nougat article
        authors. These expressions are commented for clarity and tested end-to-end in most cases.

        Args:
            generation (str): The generated text to be postprocessed.
            fix_markdown (bool, optional): Whether to perform Markdown formatting fixes. Default is True.

        Returns:
            str: The postprocessed text.
        """
        # Section titles of 100+ characters are most likely not titles: keep the text, drop the heading markup.
        generation = re.sub(r"(?:\n|^)#+ \d*\W? ?(.{100,})", r"\n\1", generation)
        generation = generation.strip()
        # Remove LaTeX left margin tag
        generation = generation.replace("\n* [leftmargin=*]\n", "\n")
        # Remove markdown headings that hold only numbering (digits/dots or roman numerals), together with
        # trailing whitespace and the following newline.
        generation = re.sub(r"^#+ (?:[\d+\.]+|[ixv\.]+)?\s*(?:$|\n\s*)", "", generation, flags=re.MULTILINE)
        # A heading on the last line of the page is most likely hallucinated.
        lines = generation.split("\n")
        if lines[-1].startswith("#") and lines[-1].lstrip("#").startswith(" ") and len(lines) > 1:
            logger.info("Likely hallucinated title at the end of the page: %s", lines[-1])
            generation = "\n".join(lines[:-1])
        # obvious repetition detection
        generation = truncate_repetitions(generation)
        # Reference corrections
        generation = self.remove_hallucinated_references(generation)
        # Remove "* [n]" reference lines followed by 10 or more capital initials such as "A. B. C." (too long refs)
        generation = re.sub(r"^\* \[\d+\](\s?[A-W]\.+\s?){10,}.*$", "", generation, flags=re.MULTILINE)
        # Remove empty brackets after a reference number in brackets. "* [12][]ABC" will become "* [12]ABC"
        generation = re.sub(r"^(\* \[\d+\])\[\](.*)$", r"\1\2", generation, flags=re.MULTILINE)
        # Remove a single word character separated by a blank line at the very start or very end of the text
        generation = re.sub(r"(^\w\n\n|\n\n\w$)", "", generation)
        # pmc math artifact correction
        generation = re.sub(
            r"([\s.,()])_([a-zA-Z0-9])__([a-zA-Z0-9]){1,3}_([\s.,:()])",
            r"\1\(\2_{\3}\)\4",
            generation,
        )
        generation = re.sub(r"([\s.,\d])_([a-zA-Z0-9])_([\s.,\d;])", r"\1\(\2\)\3", generation)
        # footnote mistakes
        generation = re.sub(
            r"(\nFootnote .*?:) (?:footnotetext|thanks):\W*(.*(?:\n\n|$))",
            r"\1 \2",
            generation,
        )
        # TODO Come up with footnote formatting inside a table
        generation = re.sub(r"\[FOOTNOTE:.+?\](.*?)\[ENDFOOTNOTE\]", "", generation)
        # itemize post processing
        generation = normalize_list_like_lines(generation)

        if generation.endswith((".", "}")):
            generation += "\n\n"
        # NOTE(review): `re.match` anchors at the start of the string, so this branch only fires when the whole
        # generation is a single such character; `re.search` may have been intended. Kept for parity with the
        # reference Nougat post-processing.
        if re.match(r"[A-Z0-9,;:]$", generation):
            # add space in case there is a comma or word ending
            generation += " "
        elif generation.startswith(("#", "**", "\\begin")):
            generation = "\n\n" + generation
        elif generation.split("\n")[-1].startswith(("#", "Figure", "Table")):
            generation = generation + "\n\n"
        else:
            try:
                last_word = generation.split(" ")[-1]
                if self._is_english_word(last_word):
                    generation += " "
            except LookupError:
                # add space just in case. Will split words but better than concatenating them
                generation += " "

        # table corrections
        generation = self.correct_tables(generation)
        # Remove optional, empty square brackets after begin{array}
        generation = generation.replace("\\begin{array}[]{", "\\begin{array}{")
        # Remove empty or malformed LaTeX tabular blocks with 2 or more columns specified, with spaces and ampersands.
        generation = re.sub(
            r"\\begin{tabular}{([clr ]){2,}}\s*[& ]*\s*(\\\\)? \\end{tabular}",
            "",
            generation,
        )
        # Remove runs of two or more consecutive "**S. A. B.**" lines. Was included in Nougat's code.
        generation = re.sub(r"(\*\*S\. A\. B\.\*\*\n+){2,}", "", generation)
        # Remove markdown headings that are empty or hold a single character.
        generation = re.sub(r"^#+( [\[\d\w])?$", "", generation, flags=re.MULTILINE)
        # Remove lines with just one period.
        generation = re.sub(r"^\.\s*$", "", generation, flags=re.MULTILINE)
        # Replace instances of three or more newlines with just two newlines.
        generation = re.sub(r"\n{3,}", "\n\n", generation)
        if fix_markdown:
            return markdown_compatible(generation)
        return generation

    def post_process_generation(
        self,
        generation: str | list[str],
        fix_markdown: bool = True,
        num_workers: int | None = None,
    ) -> str | list[str]:
        """
        Postprocess a generated text or a list of generated texts.

        This function can be used to perform postprocessing on generated text, such as fixing Markdown formatting.

        Postprocessing is quite slow so it is recommended to use multiprocessing to speed up the process.

        Args:
            generation (Union[str, list[str]]):
                The generated text or a list of generated texts.
            fix_markdown (`bool`, *optional*, defaults to `True`):
                Whether to perform Markdown formatting fixes.
            num_workers (`int`, *optional*):
                Optional number of workers to pass to leverage multiprocessing (postprocessing several texts in
                parallel). Only used when `generation` is a list.

        Returns:
            Union[str, list[str]]: The postprocessed text or list of postprocessed texts.
        """
        requires_backends(self, ["nltk", "levenshtein"])

        if isinstance(generation, list):
            if num_workers is not None and isinstance(num_workers, int):
                with Pool(num_workers) as p:
                    return p.map(partial(self.post_process_single, fix_markdown=fix_markdown), generation)
            else:
                return [self.post_process_single(s, fix_markdown=fix_markdown) for s in generation]
        else:
            return self.post_process_single(generation, fix_markdown=fix_markdown)
# Public API of this module: only the tokenizer class is exported; the post-processing helpers stay module-internal.
__all__ = ["NougatTokenizer"]
|