token_classification.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. import types
  2. import warnings
  3. from typing import Any, overload
  4. import numpy as np
  5. from ..models.bert.tokenization_bert_legacy import BasicTokenizer
  6. from ..utils import (
  7. ExplicitEnum,
  8. add_end_docstrings,
  9. is_torch_available,
  10. )
  11. from .base import ArgumentHandler, ChunkPipeline, Dataset, build_pipeline_init_args
  12. if is_torch_available():
  13. import torch
  14. from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
  15. class TokenClassificationArgumentHandler(ArgumentHandler):
  16. """
  17. Handles arguments for token classification.
  18. """
  19. def __call__(self, inputs: str | list[str], **kwargs):
  20. is_split_into_words = kwargs.get("is_split_into_words", False)
  21. delimiter = kwargs.get("delimiter")
  22. if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
  23. inputs = list(inputs)
  24. batch_size = len(inputs)
  25. elif isinstance(inputs, str):
  26. inputs = [inputs]
  27. batch_size = 1
  28. elif Dataset is not None and isinstance(inputs, Dataset) or isinstance(inputs, types.GeneratorType):
  29. return inputs, is_split_into_words, None, delimiter
  30. else:
  31. raise ValueError("At least one input is required.")
  32. offset_mapping = kwargs.get("offset_mapping")
  33. if offset_mapping:
  34. if isinstance(offset_mapping, list) and isinstance(offset_mapping[0], tuple):
  35. offset_mapping = [offset_mapping]
  36. if len(offset_mapping) != batch_size:
  37. raise ValueError("offset_mapping should have the same batch size as the input")
  38. return inputs, is_split_into_words, offset_mapping, delimiter
class AggregationStrategy(ExplicitEnum):
    """All the valid aggregation strategies for TokenClassificationPipeline"""

    # No aggregation: one result dict per non-special token.
    NONE = "none"
    # Per-token labels, then adjacent tokens with the same tag are grouped (B-/I- aware).
    SIMPLE = "simple"
    # Sub-tokens are fused into words; a word takes the label of its first sub-token.
    FIRST = "first"
    # Sub-tokens are fused into words; their score vectors are averaged before taking the argmax.
    AVERAGE = "average"
    # Sub-tokens are fused into words; the sub-token with the highest single score decides the label.
    MAX = "max"
  46. @add_end_docstrings(
  47. build_pipeline_init_args(has_tokenizer=True),
  48. r"""
  49. ignore_labels (`list[str]`, defaults to `["O"]`):
  50. A list of labels to ignore.
  51. stride (`int`, *optional*):
  52. If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
  53. model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
  54. value of this argument defines the number of overlapping tokens between chunks. In other words, the model
  55. will shift forward by `tokenizer.model_max_length - stride` tokens each step.
  56. aggregation_strategy (`str`, *optional*, defaults to `"none"`):
  57. The strategy to fuse (or not) tokens based on the model prediction.
  58. - "none" : Will simply not do any aggregation and simply return raw results from the model
  59. - "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
  60. I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
  61. "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
  62. different entities. On word based languages, we might end up splitting words undesirably : Imagine
  63. Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
  64. "NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
  65. that support that meaning, which is basically tokens separated by a space). These mitigations will
  66. only work on real words, "New york" might still be tagged with two different entities.
  67. - "first" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
  68. end up with different tags. Words will simply use the tag of the first token of the word when there
  69. is ambiguity.
  70. - "average" : (works only on word based models) Will use the `SIMPLE` strategy except that words,
  71. cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
  72. label is applied.
  73. - "max" : (works only on word based models) Will use the `SIMPLE` strategy except that words, cannot
  74. end up with different tags. Word entity will simply be the token with the maximum score.""",
  75. )
  76. class TokenClassificationPipeline(ChunkPipeline):
  77. """
  78. Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
  79. examples](../task_summary#named-entity-recognition) for more information.
  80. Example:
  81. ```python
  82. >>> from transformers import pipeline
  83. >>> token_classifier = pipeline(model="Jean-Baptiste/camembert-ner", aggregation_strategy="simple")
  84. >>> sentence = "Je m'appelle jean-baptiste et je vis à montréal"
  85. >>> tokens = token_classifier(sentence)
  86. >>> tokens
  87. [{'entity_group': 'PER', 'score': 0.9931, 'word': 'jean-baptiste', 'start': 12, 'end': 26}, {'entity_group': 'LOC', 'score': 0.998, 'word': 'montréal', 'start': 38, 'end': 47}]
  88. >>> token = tokens[0]
  89. >>> # Start and end provide an easy way to highlight words in the original text.
  90. >>> sentence[token["start"] : token["end"]]
  91. ' jean-baptiste'
  92. >>> # Some models use the same idea to do part of speech.
  93. >>> syntaxer = pipeline(model="vblagoje/bert-english-uncased-finetuned-pos", aggregation_strategy="simple")
  94. >>> syntaxer("My name is Sarah and I live in London")
  95. [{'entity_group': 'PRON', 'score': 0.999, 'word': 'my', 'start': 0, 'end': 2}, {'entity_group': 'NOUN', 'score': 0.997, 'word': 'name', 'start': 3, 'end': 7}, {'entity_group': 'AUX', 'score': 0.994, 'word': 'is', 'start': 8, 'end': 10}, {'entity_group': 'PROPN', 'score': 0.999, 'word': 'sarah', 'start': 11, 'end': 16}, {'entity_group': 'CCONJ', 'score': 0.999, 'word': 'and', 'start': 17, 'end': 20}, {'entity_group': 'PRON', 'score': 0.999, 'word': 'i', 'start': 21, 'end': 22}, {'entity_group': 'VERB', 'score': 0.998, 'word': 'live', 'start': 23, 'end': 27}, {'entity_group': 'ADP', 'score': 0.999, 'word': 'in', 'start': 28, 'end': 30}, {'entity_group': 'PROPN', 'score': 0.999, 'word': 'london', 'start': 31, 'end': 37}]
  96. ```
  97. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  98. This token recognition pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  99. `"ner"` (for predicting the classes of tokens in a sequence: person, organisation, location or miscellaneous).
  100. The models that this pipeline can use are models that have been fine-tuned on a token classification task. See the
  101. up-to-date list of available models on
  102. [huggingface.co/models](https://huggingface.co/models?filter=token-classification).
  103. """
  104. default_input_names = "sequences"
  105. _load_processor = False
  106. _load_image_processor = False
  107. _load_feature_extractor = False
  108. _load_tokenizer = True
  109. def __init__(self, args_parser=TokenClassificationArgumentHandler(), **kwargs):
  110. super().__init__(**kwargs)
  111. self.check_model_type(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
  112. self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
  113. self._args_parser = args_parser
  114. def _sanitize_parameters(
  115. self,
  116. ignore_labels=None,
  117. aggregation_strategy: AggregationStrategy | None = None,
  118. offset_mapping: list[tuple[int, int]] | None = None,
  119. is_split_into_words: bool = False,
  120. stride: int | None = None,
  121. delimiter: str | None = None,
  122. ):
  123. preprocess_params = {}
  124. preprocess_params["is_split_into_words"] = is_split_into_words
  125. if is_split_into_words:
  126. preprocess_params["delimiter"] = " " if delimiter is None else delimiter
  127. if offset_mapping is not None:
  128. preprocess_params["offset_mapping"] = offset_mapping
  129. postprocess_params = {}
  130. if aggregation_strategy is not None:
  131. if isinstance(aggregation_strategy, str):
  132. aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
  133. if (
  134. aggregation_strategy
  135. in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
  136. and not self.tokenizer.is_fast
  137. ):
  138. raise ValueError(
  139. "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
  140. ' to `"simple"` or use a fast tokenizer.'
  141. )
  142. postprocess_params["aggregation_strategy"] = aggregation_strategy
  143. if ignore_labels is not None:
  144. postprocess_params["ignore_labels"] = ignore_labels
  145. if stride is not None:
  146. if stride >= self.tokenizer.model_max_length:
  147. raise ValueError(
  148. "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
  149. )
  150. if aggregation_strategy == AggregationStrategy.NONE:
  151. raise ValueError(
  152. "`stride` was provided to process all the text but `aggregation_strategy="
  153. f'"{aggregation_strategy}"`, please select another one instead.'
  154. )
  155. else:
  156. if self.tokenizer.is_fast:
  157. tokenizer_params = {
  158. "return_overflowing_tokens": True,
  159. "padding": True,
  160. "stride": stride,
  161. }
  162. preprocess_params["tokenizer_params"] = tokenizer_params
  163. else:
  164. raise ValueError(
  165. "`stride` was provided to process all the text but you're using a slow tokenizer."
  166. " Please use a fast tokenizer."
  167. )
  168. return preprocess_params, {}, postprocess_params
  169. @overload
  170. def __call__(self, inputs: str, **kwargs: Any) -> list[dict[str, str]]: ...
  171. @overload
  172. def __call__(self, inputs: list[str], **kwargs: Any) -> list[list[dict[str, str]]]: ...
  173. def __call__(self, inputs: str | list[str], **kwargs: Any) -> list[dict[str, str]] | list[list[dict[str, str]]]:
  174. """
  175. Classify each token of the text(s) given as inputs.
  176. Args:
  177. inputs (`str` or `List[str]`):
  178. One or several texts (or one list of texts) for token classification. Can be pre-tokenized when
  179. `is_split_into_words=True`.
  180. Return:
  181. A list or a list of list of `dict`: Each result comes as a list of dictionaries (one for each token in the
  182. corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy) with
  183. the following keys:
  184. - **word** (`str`) -- The token/word classified. This is obtained by decoding the selected tokens. If you
  185. want to have the exact string in the original sentence, use `start` and `end`.
  186. - **score** (`float`) -- The corresponding probability for `entity`.
  187. - **entity** (`str`) -- The entity predicted for that token/word (it is named *entity_group* when
  188. *aggregation_strategy* is not `"none"`.
  189. - **index** (`int`, only present when `aggregation_strategy="none"`) -- The index of the corresponding
  190. token in the sentence.
  191. - **start** (`int`, *optional*) -- The index of the start of the corresponding entity in the sentence. Only
  192. exists if the offsets are available within the tokenizer
  193. - **end** (`int`, *optional*) -- The index of the end of the corresponding entity in the sentence. Only
  194. exists if the offsets are available within the tokenizer
  195. """
  196. _inputs, is_split_into_words, offset_mapping, delimiter = self._args_parser(inputs, **kwargs)
  197. kwargs["is_split_into_words"] = is_split_into_words
  198. kwargs["delimiter"] = delimiter
  199. if is_split_into_words and not all(isinstance(input, list) for input in inputs):
  200. return super().__call__([inputs], **kwargs)
  201. if offset_mapping:
  202. kwargs["offset_mapping"] = offset_mapping
  203. return super().__call__(inputs, **kwargs)
  204. def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
  205. tokenizer_params = preprocess_params.pop("tokenizer_params", {})
  206. truncation = self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0
  207. word_to_chars_map = None
  208. is_split_into_words = preprocess_params["is_split_into_words"]
  209. if is_split_into_words:
  210. delimiter = preprocess_params["delimiter"]
  211. if not isinstance(sentence, list):
  212. raise ValueError("When `is_split_into_words=True`, `sentence` must be a list of tokens.")
  213. words = sentence
  214. sentence = delimiter.join(words) # Recreate the sentence string for later display and slicing
  215. # This map will allow to convert back word => char indices
  216. word_to_chars_map = []
  217. delimiter_len = len(delimiter)
  218. char_offset = 0
  219. for word in words:
  220. word_to_chars_map.append((char_offset, char_offset + len(word)))
  221. char_offset += len(word) + delimiter_len
  222. # We use `words` as the actual input for the tokenizer
  223. text_to_tokenize = words
  224. tokenizer_params["is_split_into_words"] = True
  225. else:
  226. if not isinstance(sentence, str):
  227. raise ValueError("When `is_split_into_words=False`, `sentence` must be an untokenized string.")
  228. text_to_tokenize = sentence
  229. inputs = self.tokenizer(
  230. text_to_tokenize,
  231. return_tensors="pt",
  232. truncation=truncation,
  233. return_special_tokens_mask=True,
  234. return_offsets_mapping=self.tokenizer.is_fast,
  235. **tokenizer_params,
  236. )
  237. if is_split_into_words and not self.tokenizer.is_fast:
  238. raise ValueError("is_split_into_words=True is only supported with fast tokenizers.")
  239. inputs.pop("overflow_to_sample_mapping", None)
  240. num_chunks = len(inputs["input_ids"])
  241. for i in range(num_chunks):
  242. model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
  243. if offset_mapping is not None:
  244. model_inputs["offset_mapping"] = offset_mapping
  245. model_inputs["sentence"] = sentence if i == 0 else None
  246. model_inputs["is_last"] = i == num_chunks - 1
  247. if word_to_chars_map is not None:
  248. model_inputs["word_ids"] = inputs.word_ids(i)
  249. model_inputs["word_to_chars_map"] = word_to_chars_map
  250. yield model_inputs
  251. def _forward(self, model_inputs):
  252. # Forward
  253. special_tokens_mask = model_inputs.pop("special_tokens_mask")
  254. offset_mapping = model_inputs.pop("offset_mapping", None)
  255. sentence = model_inputs.pop("sentence")
  256. is_last = model_inputs.pop("is_last")
  257. word_ids = model_inputs.pop("word_ids", None)
  258. word_to_chars_map = model_inputs.pop("word_to_chars_map", None)
  259. output = self.model(**model_inputs)
  260. logits = output["logits"] if isinstance(output, dict) else output[0]
  261. return {
  262. "logits": logits,
  263. "special_tokens_mask": special_tokens_mask,
  264. "offset_mapping": offset_mapping,
  265. "sentence": sentence,
  266. "is_last": is_last,
  267. "word_ids": word_ids,
  268. "word_to_chars_map": word_to_chars_map,
  269. **model_inputs,
  270. }
  271. def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
  272. if ignore_labels is None:
  273. ignore_labels = ["O"]
  274. all_entities = []
  275. # Get map from the first output, it's the same for all chunks
  276. word_to_chars_map = all_outputs[0].get("word_to_chars_map")
  277. for model_outputs in all_outputs:
  278. if model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
  279. logits = model_outputs["logits"][0].to(torch.float32).numpy()
  280. else:
  281. logits = model_outputs["logits"][0].numpy()
  282. sentence = all_outputs[0]["sentence"]
  283. input_ids = model_outputs["input_ids"][0]
  284. offset_mapping = (
  285. model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
  286. )
  287. special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
  288. word_ids = model_outputs.get("word_ids")
  289. maxes = np.max(logits, axis=-1, keepdims=True)
  290. shifted_exp = np.exp(logits - maxes)
  291. scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
  292. pre_entities = self.gather_pre_entities(
  293. sentence,
  294. input_ids,
  295. scores,
  296. offset_mapping,
  297. special_tokens_mask,
  298. aggregation_strategy,
  299. word_ids=word_ids,
  300. word_to_chars_map=word_to_chars_map,
  301. )
  302. grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
  303. # Filter anything that is in self.ignore_labels
  304. entities = [
  305. entity
  306. for entity in grouped_entities
  307. if entity.get("entity", None) not in ignore_labels
  308. and entity.get("entity_group", None) not in ignore_labels
  309. ]
  310. all_entities.extend(entities)
  311. num_chunks = len(all_outputs)
  312. if num_chunks > 1:
  313. all_entities = self.aggregate_overlapping_entities(all_entities)
  314. return all_entities
  315. def aggregate_overlapping_entities(self, entities):
  316. if len(entities) == 0:
  317. return entities
  318. entities = sorted(entities, key=lambda x: x["start"])
  319. aggregated_entities = []
  320. previous_entity = entities[0]
  321. for entity in entities:
  322. if previous_entity["start"] <= entity["start"] < previous_entity["end"]:
  323. current_length = entity["end"] - entity["start"]
  324. previous_length = previous_entity["end"] - previous_entity["start"]
  325. if (
  326. current_length > previous_length
  327. or current_length == previous_length
  328. and entity["score"] > previous_entity["score"]
  329. ):
  330. previous_entity = entity
  331. else:
  332. aggregated_entities.append(previous_entity)
  333. previous_entity = entity
  334. aggregated_entities.append(previous_entity)
  335. return aggregated_entities
  336. def gather_pre_entities(
  337. self,
  338. sentence: str,
  339. input_ids: np.ndarray,
  340. scores: np.ndarray,
  341. offset_mapping: list[tuple[int, int]] | None,
  342. special_tokens_mask: np.ndarray,
  343. aggregation_strategy: AggregationStrategy,
  344. word_ids: list[int | None] | None = None,
  345. word_to_chars_map: list[tuple[int, int]] | None = None,
  346. ) -> list[dict]:
  347. """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
  348. pre_entities = []
  349. for idx, token_scores in enumerate(scores):
  350. # Filter special_tokens
  351. if special_tokens_mask[idx]:
  352. continue
  353. word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
  354. if offset_mapping is not None:
  355. start_ind, end_ind = offset_mapping[idx]
  356. # If the input is pre-tokenized, we need to rescale the offsets to the absolute sentence.
  357. if word_ids is not None and word_to_chars_map is not None:
  358. word_index = word_ids[idx]
  359. if word_index is not None:
  360. start_char, _ = word_to_chars_map[word_index]
  361. start_ind += start_char
  362. end_ind += start_char
  363. if not isinstance(start_ind, int):
  364. start_ind = start_ind.item()
  365. end_ind = end_ind.item()
  366. word_ref = sentence[start_ind:end_ind]
  367. if getattr(self.tokenizer, "_tokenizer", None) and getattr(
  368. self.tokenizer._tokenizer.model, "continuing_subword_prefix", None
  369. ):
  370. # This is a BPE, word aware tokenizer, there is a correct way
  371. # to fuse tokens
  372. is_subword = len(word) != len(word_ref)
  373. else:
  374. # This is a fallback heuristic. This will fail most likely on any kind of text + punctuation mixtures that will be considered "words". Non word aware models cannot do better than this unfortunately.
  375. if aggregation_strategy in {
  376. AggregationStrategy.FIRST,
  377. AggregationStrategy.AVERAGE,
  378. AggregationStrategy.MAX,
  379. }:
  380. warnings.warn(
  381. "Tokenizer does not support real words, using fallback heuristic",
  382. UserWarning,
  383. )
  384. is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]
  385. if int(input_ids[idx]) == self.tokenizer.unk_token_id:
  386. word = word_ref
  387. is_subword = False
  388. else:
  389. start_ind = None
  390. end_ind = None
  391. is_subword = False
  392. pre_entity = {
  393. "word": word,
  394. "scores": token_scores,
  395. "start": start_ind,
  396. "end": end_ind,
  397. "index": idx,
  398. "is_subword": is_subword,
  399. }
  400. pre_entities.append(pre_entity)
  401. return pre_entities
  402. def aggregate(self, pre_entities: list[dict], aggregation_strategy: AggregationStrategy) -> list[dict]:
  403. if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
  404. entities = []
  405. for pre_entity in pre_entities:
  406. entity_idx = pre_entity["scores"].argmax()
  407. score = pre_entity["scores"][entity_idx]
  408. entity = {
  409. "entity": self.model.config.id2label[entity_idx],
  410. "score": score,
  411. "index": pre_entity["index"],
  412. "word": pre_entity["word"],
  413. "start": pre_entity["start"],
  414. "end": pre_entity["end"],
  415. }
  416. entities.append(entity)
  417. else:
  418. entities = self.aggregate_words(pre_entities, aggregation_strategy)
  419. if aggregation_strategy == AggregationStrategy.NONE:
  420. return entities
  421. return self.group_entities(entities)
  422. def aggregate_word(self, entities: list[dict], aggregation_strategy: AggregationStrategy) -> dict:
  423. word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
  424. if aggregation_strategy == AggregationStrategy.FIRST:
  425. scores = entities[0]["scores"]
  426. idx = scores.argmax()
  427. score = scores[idx]
  428. entity = self.model.config.id2label[idx]
  429. elif aggregation_strategy == AggregationStrategy.MAX:
  430. max_entity = max(entities, key=lambda entity: entity["scores"].max())
  431. scores = max_entity["scores"]
  432. idx = scores.argmax()
  433. score = scores[idx]
  434. entity = self.model.config.id2label[idx]
  435. elif aggregation_strategy == AggregationStrategy.AVERAGE:
  436. scores = np.stack([entity["scores"] for entity in entities])
  437. average_scores = np.nanmean(scores, axis=0)
  438. entity_idx = average_scores.argmax()
  439. entity = self.model.config.id2label[entity_idx]
  440. score = average_scores[entity_idx]
  441. else:
  442. raise ValueError("Invalid aggregation_strategy")
  443. new_entity = {
  444. "entity": entity,
  445. "score": score,
  446. "word": word,
  447. "start": entities[0]["start"],
  448. "end": entities[-1]["end"],
  449. }
  450. return new_entity
  451. def aggregate_words(self, entities: list[dict], aggregation_strategy: AggregationStrategy) -> list[dict]:
  452. """
  453. Override tokens from a given word that disagree to force agreement on word boundaries.
  454. Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
  455. company| B-ENT I-ENT
  456. """
  457. if aggregation_strategy in {
  458. AggregationStrategy.NONE,
  459. AggregationStrategy.SIMPLE,
  460. }:
  461. raise ValueError("NONE and SIMPLE strategies are invalid for word aggregation")
  462. word_entities = []
  463. word_group = None
  464. for entity in entities:
  465. if word_group is None:
  466. word_group = [entity]
  467. elif entity["is_subword"]:
  468. word_group.append(entity)
  469. else:
  470. word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
  471. word_group = [entity]
  472. # Last item
  473. if word_group is not None:
  474. word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
  475. return word_entities
  476. def group_sub_entities(self, entities: list[dict]) -> dict:
  477. """
  478. Group together the adjacent tokens with the same entity predicted.
  479. Args:
  480. entities (`dict`): The entities predicted by the pipeline.
  481. """
  482. # Get the first entity in the entity group
  483. entity = entities[0]["entity"].split("-", 1)[-1]
  484. scores = np.nanmean([entity["score"] for entity in entities])
  485. tokens = [entity["word"] for entity in entities]
  486. entity_group = {
  487. "entity_group": entity,
  488. "score": np.mean(scores),
  489. "word": self.tokenizer.convert_tokens_to_string(tokens),
  490. "start": entities[0]["start"],
  491. "end": entities[-1]["end"],
  492. }
  493. return entity_group
  494. def get_tag(self, entity_name: str) -> tuple[str, str]:
  495. if entity_name.startswith("B-"):
  496. bi = "B"
  497. tag = entity_name[2:]
  498. elif entity_name.startswith("I-"):
  499. bi = "I"
  500. tag = entity_name[2:]
  501. else:
  502. # It's not in B-, I- format
  503. # Default to I- for continuation.
  504. bi = "I"
  505. tag = entity_name
  506. return bi, tag
  507. def group_entities(self, entities: list[dict]) -> list[dict]:
  508. """
  509. Find and group together the adjacent tokens with the same entity predicted.
  510. Args:
  511. entities (`dict`): The entities predicted by the pipeline.
  512. """
  513. entity_groups = []
  514. entity_group_disagg = []
  515. for entity in entities:
  516. if not entity_group_disagg:
  517. entity_group_disagg.append(entity)
  518. continue
  519. # If the current entity is similar and adjacent to the previous entity,
  520. # append it to the disaggregated entity group
  521. # The split is meant to account for the "B" and "I" prefixes
  522. # Shouldn't merge if both entities are B-type
  523. bi, tag = self.get_tag(entity["entity"])
  524. last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])
  525. if tag == last_tag and bi != "B":
  526. # Modify subword type to be previous_type
  527. entity_group_disagg.append(entity)
  528. else:
  529. # If the current entity is different from the previous entity
  530. # aggregate the disaggregated entity group
  531. entity_groups.append(self.group_sub_entities(entity_group_disagg))
  532. entity_group_disagg = [entity]
  533. if entity_group_disagg:
  534. # it's the last entity, add it to the entity groups
  535. entity_groups.append(self.group_sub_entities(entity_group_disagg))
  536. return entity_groups
# Alias: `NerPipeline` is the same class object as `TokenClassificationPipeline`
# (presumably kept so older imports of `NerPipeline` keep working — confirm before removing).
NerPipeline = TokenClassificationPipeline