| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # referenced from
- # Library Name: torchtext
- # Authors: torchtext authors and @sluks
- # Date: 2020-07-18
- # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
- ##############
- # Copyright 2017--2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License"). You may not
- # use this file except in compliance with the License. A copy of the License
- # is located at
- #
- # http://aws.amazon.com/apache2.0/
- #
- # or in the "license" file accompanying this file. This file is distributed on
- # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- # express or implied. See the License for the specific language governing
- # permissions and limitations under the License.
- ##############
- # MIT License
- # Copyright (c) 2017 - Shujian Huang <huangsj@nju.edu.cn>
- import os
- import re
- import tempfile
- from collections.abc import Sequence
- from functools import partial
- from typing import Any, ClassVar, Optional
- import torch
- from torch import Tensor, tensor
- from typing_extensions import Literal
- from torchmetrics.functional.text.bleu import _bleu_score_compute, _bleu_score_update
- from torchmetrics.utilities.imports import (
- _IPADIC_AVAILABLE,
- _MECAB_AVAILABLE,
- _MECAB_KO_AVAILABLE,
- _MECAB_KO_DIC_AVAILABLE,
- _REGEX_AVAILABLE,
- _SENTENCEPIECE_AVAILABLE,
- )
- AVAILABLE_TOKENIZERS = ("none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200")
- _TokenizersLiteral = Literal["none", "13a", "zh", "intl", "char", "ja-mecab", "ko-mecab", "flores101", "flores200"]
- _UCODE_RANGES = (
- ("\u3400", "\u4db5"), # CJK Unified Ideographs Extension A, release 3.0
- ("\u4e00", "\u9fa5"), # CJK Unified Ideographs, release 1.1
- ("\u9fa6", "\u9fbb"), # CJK Unified Ideographs, release 4.1
- ("\uf900", "\ufa2d"), # CJK Compatibility Ideographs, release 1.1
- ("\ufa30", "\ufa6a"), # CJK Compatibility Ideographs, release 3.2
- ("\ufa70", "\ufad9"), # CJK Compatibility Ideographs, release 4.1
- ("\u20000", "\u2a6d6"), # (UTF16) CJK Unified Ideographs Extension B, release 3.1
- ("\u2f800", "\u2fa1d"), # (UTF16) CJK Compatibility Supplement, release 3.1
- ("\uff00", "\uffef"), # Full width ASCII, full width of English punctuation,
- # half width Katakana, half wide half width kana, Korean alphabet
- ("\u2e80", "\u2eff"), # CJK Radicals Supplement
- ("\u3000", "\u303f"), # CJK punctuation mark
- ("\u31c0", "\u31ef"), # CJK stroke
- ("\u2f00", "\u2fdf"), # Kangxi Radicals
- ("\u2ff0", "\u2fff"), # Chinese character structure
- ("\u3100", "\u312f"), # Phonetic symbols
- ("\u31a0", "\u31bf"), # Phonetic symbols (Taiwanese and Hakka expansion)
- ("\ufe10", "\ufe1f"),
- ("\ufe30", "\ufe4f"),
- ("\u2600", "\u26ff"),
- ("\u2700", "\u27bf"),
- ("\u3200", "\u32ff"),
- ("\u3300", "\u33ff"),
- )
- _FLORES_LOCAL_DIR = os.path.join(tempfile.gettempdir(), "torchmetrics-flores")
- # Model paths copied from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_spm.py.
- _FLORES_MODELS_URL = {
- "flores101": "https://dl.fbaipublicfiles.com/fairseq/models/flores/sacrebleu_tokenizer_spm.model",
- "flores200": "https://tinyurl.com/flores200sacrebleuspm",
- }
- class _SacreBLEUTokenizer:
- """Tokenizer used for SacreBLEU calculation.
- Source: https://github.com/mjpost/sacrebleu/tree/master/sacrebleu/tokenizers
- """
- _REGEX = (
- # language-dependent part (assuming Western languages)
- (re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "),
- # tokenize period and comma unless preceded by a digit
- (re.compile(r"([^0-9])([\.,])"), r"\1 \2 "),
- # tokenize period and comma unless followed by a digit
- (re.compile(r"([\.,])([^0-9])"), r" \1 \2"),
- # tokenize dash when preceded by a digit
- (re.compile(r"([0-9])(-)"), r"\1 \2 "),
- # one space only between words
- # NOTE: Doing this in Python (below) is faster
- # (re.compile(r'\s+'), r' '),
- )
- if _REGEX_AVAILABLE:
- import regex
- _INT_REGEX = (
- # Separate out punctuation preceded by a non-digit
- (regex.compile(r"(\P{N})(\p{P})"), r"\1 \2 "),
- # Separate out punctuation followed by a non-digit
- (regex.compile(r"(\p{P})(\P{N})"), r" \1 \2"),
- # Separate out symbols
- (regex.compile(r"(\p{S})"), r" \1 "),
- )
- _TOKENIZE_FN: ClassVar[dict] = {
- "none": "_tokenize_base",
- "13a": "_tokenize_13a",
- "zh": "_tokenize_zh",
- "intl": "_tokenize_international",
- "char": "_tokenize_char",
- "ja-mecab": "_tokenize_ja_mecab",
- "ko-mecab": "_tokenize_ko_mecab",
- "flores101": "_tokenize_flores_101",
- "flores200": "_tokenize_flores_200",
- }
- # Keep it as class variable to avoid initializing over and over again
- sentencepiece_processors: ClassVar[dict[str, Optional[Any]]] = {"flores101": None, "flores200": None}
- def __init__(self, tokenize: _TokenizersLiteral, lowercase: bool = False) -> None:
- self._check_tokenizers_validity(tokenize)
- self.tokenize_fn = getattr(self, self._TOKENIZE_FN[tokenize])
- self.lowercase = lowercase
- def __call__(self, line: str) -> Sequence[str]:
- tokenized_line = self.tokenize_fn(line)
- return self._lower(tokenized_line, self.lowercase).split()
- @classmethod
- def tokenize(
- cls: type["_SacreBLEUTokenizer"],
- line: str,
- tokenize: _TokenizersLiteral,
- lowercase: bool = False,
- ) -> Sequence[str]:
- cls._check_tokenizers_validity(tokenize)
- tokenize_fn = getattr(cls, cls._TOKENIZE_FN[tokenize])
- tokenized_line = tokenize_fn(line)
- return cls._lower(tokenized_line, lowercase).split()
- @classmethod
- def _tokenize_regex(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Post-processing tokenizer for `13a` and `zh` tokenizers.
- Args:
- line: a segment to tokenize
- Return:
- the tokenized line
- """
- for _re, repl in cls._REGEX:
- line = _re.sub(repl, line)
- # no leading or trailing spaces, single space within words
- return " ".join(line.split())
- @staticmethod
- def _is_chinese_char(uchar: str) -> bool:
- """Check if character is chinese.
- Args:
- uchar: input char in unicode.
- Return:
- whether the input char is a Chinese character.
- """
- return any(start <= uchar <= end for start, end in _UCODE_RANGES)
- @classmethod
- def _tokenize_base(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Tokenizes an input line with the tokenizer.
- Args:
- line: a segment to tokenize
- Return:
- the tokenized line
- """
- return line
- @classmethod
- def _tokenize_13a(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Tokenizes a line using a relatively minimal tokenization that is equivalent to mteval-v13a, used by WMT.
- Args:
- line: input sentence
- Return:
- tokenized sentence
- """
- # language-independent part:
- line = line.replace("<skipped>", "")
- line = line.replace("-\n", "")
- line = line.replace("\n", " ")
- if "&" in line:
- line = line.replace(""", '"')
- line = line.replace("&", "&")
- line = line.replace("<", "<")
- line = line.replace(">", ">")
- return cls._tokenize_regex(f" {line} ")
- @classmethod
- def _tokenize_zh(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Tokenization of Chinese text.
- This is done in two steps: separate each Chinese characters (by utf-8 encoding) and afterwards tokenize the
- Chinese part (following the `13a` i.e. mteval tokenizer).
- Author: Shujian Huang huangsj@nju.edu.cn.
- Args:
- line: input sentence
- Return:
- tokenized sentence
- """
- line = line.strip()
- line_in_chars = ""
- for char in line:
- if cls._is_chinese_char(char):
- line_in_chars += " "
- line_in_chars += char
- line_in_chars += " "
- else:
- line_in_chars += char
- return cls._tokenize_regex(line_in_chars)
- @classmethod
- def _tokenize_international(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- r"""Tokenizes a string following the official BLEU implementation.
- See github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
- In our case, the input string is expected to be just one line.
- We just tokenize on punctuation and symbols,
- except when a punctuation is preceded and followed by a digit
- (e.g. a comma/dot as a thousand/decimal separator).
- We do not recover escaped forms of punctuation such as ' or >
- as these should never appear in MT system outputs (see issue #138)
- Note that a number (e.g., a year) followed by a dot at the end of
- sentence is NOT tokenized, i.e. the dot stays with the number because
- `s/(\\p{P})(\\P{N})/ $1 $2/g` does not match this case (unless we add a
- space after each sentence). However, this error is already in the
- original mteval-v14.pl and we want to be consistent with it.
- The error is not present in the non-international version,
- which uses `$norm_text = " $norm_text "`.
- Args:
- line: the input string to tokenize.
- Return:
- The tokenized string.
- """
- for _re, repl in cls._INT_REGEX:
- line = _re.sub(repl, line)
- return " ".join(line.split())
- @classmethod
- def _tokenize_char(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Tokenizes all the characters in the input line.
- Args:
- line: a segment to tokenize
- Return:
- the tokenized line
- """
- return " ".join(char for char in line)
- @classmethod
- def _tokenize_ja_mecab(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Tokenizes a Japanese string line using MeCab morphological analyzer.
- Args:
- line: the input string to tokenize.
- Return:
- The tokenized string.
- """
- import ipadic
- import MeCab
- tagger = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")
- line = line.strip()
- return tagger.parse(line).strip()
- @classmethod
- def _tokenize_ko_mecab(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Tokenizes a Korean string line using MeCab-korean morphological analyzer.
- Args:
- line: the input string to tokenize.
- Return:
- The tokenized string.
- """
- import mecab_ko
- import mecab_ko_dic
- tagger = mecab_ko.Tagger(mecab_ko_dic.MECAB_ARGS + " -Owakati")
- line = line.strip()
- return tagger.parse(line).strip()
- @classmethod
- def _tokenize_flores(
- cls: type["_SacreBLEUTokenizer"], line: str, tokenize: Literal["flores101", "flores200"]
- ) -> str:
- """Tokenizes a string line using sentencepiece tokenizer.
- Args:
- line: the input string to tokenize.
- tokenize: Tokenization technique to be used.
- Return:
- The tokenized string.
- """
- import sentencepiece
- if cls.sentencepiece_processors[tokenize] is None:
- cls.sentencepiece_processors[tokenize] = sentencepiece.SentencePieceProcessor()
- file_path = os.path.join(_FLORES_LOCAL_DIR, _FLORES_MODELS_URL[tokenize].split("/")[-1])
- if not os.path.exists(file_path):
- cls.download_flores_file(tokenize)
- cls.sentencepiece_processors[tokenize].Load(file_path) # type: ignore[union-attr]
- return " ".join(cls.sentencepiece_processors[tokenize].EncodeAsPieces(line)) # type: ignore[union-attr]
- @classmethod
- def _tokenize_flores_101(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-101`_ dataset.
- Args:
- line: the input string to tokenize.
- Return:
- The tokenized string.
- """
- return cls._tokenize_flores(line, "flores101")
- @classmethod
- def _tokenize_flores_200(cls: type["_SacreBLEUTokenizer"], line: str) -> str:
- """Tokenizes a string line using sentencepiece tokenizer according to `FLORES-200`_ dataset.
- Args:
- line: the input string to tokenize.
- Return:
- The tokenized string.
- """
- return cls._tokenize_flores(line, "flores200")
- @staticmethod
- def _lower(line: str, lowercase: bool) -> str:
- if lowercase:
- return line.lower()
- return line
- @classmethod
- def _check_tokenizers_validity(cls: type["_SacreBLEUTokenizer"], tokenize: _TokenizersLiteral) -> None:
- """Check if a supported tokenizer is chosen.
- Also check all dependencies of a given tokenizers are installed.
- """
- if tokenize not in cls._TOKENIZE_FN:
- raise ValueError(f"Unsupported tokenizer selected. Please, choose one of {list(cls._TOKENIZE_FN.keys())}")
- if tokenize == "intl" and not _REGEX_AVAILABLE:
- raise ModuleNotFoundError(
- "`'intl'` tokenization requires that `regex` is installed."
- " Use `pip install regex` or `pip install torchmetrics[text]`."
- )
- if tokenize == "ja-mecab" and not (_MECAB_AVAILABLE and _IPADIC_AVAILABLE):
- raise ModuleNotFoundError(
- "`'ja-mecab'` tokenization requires that `MeCab` and `ipadic` are installed."
- " Use `pip install mecab-python3 ipadic` or `pip install torchmetrics[text]`."
- )
- if tokenize == "ko-mecab" and not (_MECAB_KO_AVAILABLE and _MECAB_KO_DIC_AVAILABLE):
- raise ModuleNotFoundError(
- "`'ko-mecab'` tokenization requires that `mecab_ko` and `mecab_ko_dic` are installed."
- " Use `pip install mecab_ko mecab_ko_dic` or `pip install torchmetrics[text]`."
- )
- if "flores" in tokenize and not _SENTENCEPIECE_AVAILABLE:
- raise ModuleNotFoundError(
- "`'flores101' and 'flores200'` tokenizations require that `sentencepiece` is installed."
- " Use `pip install sentencepiece` or `pip install torchmetrics[text]`."
- )
- @staticmethod
- def download_flores_file(model_name: Literal["flores101", "flores200"]) -> None:
- """Download necessary files for `flores` tokenization via `sentencepiece`."""
- import ssl
- import urllib.request
- os.makedirs(_FLORES_LOCAL_DIR, exist_ok=True)
- model_url = _FLORES_MODELS_URL[model_name]
- file_path = os.path.join(_FLORES_LOCAL_DIR, model_url.split("/")[-1])
- try:
- with open(file_path, "wb") as out_file, urllib.request.urlopen(model_url) as remote_file:
- out_file.write(remote_file.read())
- except ssl.SSLError as e:
- raise OSError(f"Failed to download {model_name} model.") from e
- def sacre_bleu_score(
- preds: Sequence[str],
- target: Sequence[Sequence[str]],
- n_gram: int = 4,
- smooth: bool = False,
- tokenize: _TokenizersLiteral = "13a",
- lowercase: bool = False,
- weights: Optional[Sequence[float]] = None,
- ) -> Tensor:
- """Calculate `BLEU score`_ [1] of machine translated text with one or more references.
- This implementation follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu.
- .. note::
- In the original SacreBLEU, references are passed as a list of reference sets (grouped by reference index).
- In TorchMetrics, references are passed grouped per prediction (each prediction has its own list of references).
- For example::
- # Predictions
- preds = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']
- # Original SacreBLEU:
- refs = [
- ['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'], # First set
- ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.'], # Second set
- ]
- # TorchMetrics SacreBLEU:
- target = [
- ['The dog bit the man.', 'The dog had bit the man.'], # References for first prediction
- ['It was not unexpected.', 'No one was surprised.'], # References for second prediction
- ['The man bit him first.', 'The man had bitten the dog.'], # References for third prediction
- ]
- Args:
- preds: An iterable of machine translated corpus
- target: An iterable of iterables of reference corpus
- n_gram: Gram value ranged from 1 to 4
- smooth: Whether to apply smoothing - see [2]
- tokenize: Tokenization technique to be used. Choose between ``'none'``, ``'13a'``, ``'zh'``, ``'intl'``,
- ``'char'``, ``'ja-mecab'``, ``'ko-mecab'``, ``'flores101'`` and ``'flores200'``.
- lowercase: If ``True``, BLEU score over lowercased text is calculated.
- weights:
- Weights used for unigrams, bigrams, etc. to calculate BLEU score.
- If not provided, uniform weights are used.
- Return:
- Tensor with BLEU Score
- Raises:
- ValueError: If ``preds`` and ``target`` corpus have different lengths.
- ValueError: If a length of a list of weights is not ``None`` and not equal to ``n_gram``.
- Example:
- >>> from torchmetrics.functional.text import sacre_bleu_score
- >>> preds = ['the cat is on the mat']
- >>> target = [['there is a cat on the mat', 'a cat is on the mat']]
- >>> sacre_bleu_score(preds, target)
- tensor(0.7598)
- References:
- [1] BLEU: a Method for Automatic Evaluation of Machine Translation by Papineni,
- Kishore, Salim Roukos, Todd Ward, and Wei-Jing Zhu `BLEU`_
- [2] A Call for Clarity in Reporting BLEU Scores by Matt Post.
- [3] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence
- and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_
- """
- if len(preds) != len(target):
- raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}")
- if weights is not None and len(weights) != n_gram:
- raise ValueError(f"List of weights has different weights than `n_gram`: {len(weights)} != {n_gram}")
- if weights is None:
- weights = [1.0 / n_gram] * n_gram
- numerator = torch.zeros(n_gram)
- denominator = torch.zeros(n_gram)
- preds_len = tensor(0.0)
- target_len = tensor(0.0)
- tokenize_fn = partial(_SacreBLEUTokenizer.tokenize, tokenize=tokenize, lowercase=lowercase)
- preds_len, target_len = _bleu_score_update(
- preds,
- target,
- numerator,
- denominator,
- preds_len,
- target_len,
- n_gram,
- tokenize_fn,
- )
- return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, weights, smooth)
|