zero_shot_classification.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. import inspect
  2. import numpy as np
  3. from ..tokenization_python import TruncationStrategy
  4. from ..utils import add_end_docstrings, logging
  5. from .base import ArgumentHandler, ChunkPipeline, build_pipeline_init_args
# Module-level logger; the pipeline below uses it for its entailment-id warning and missing-pad-token error.
logger = logging.get_logger(__name__)
  7. class ZeroShotClassificationArgumentHandler(ArgumentHandler):
  8. """
  9. Handles arguments for zero-shot for text classification by turning each possible label into an NLI
  10. premise/hypothesis pair.
  11. """
  12. def _parse_labels(self, labels):
  13. if isinstance(labels, str):
  14. labels = [label.strip() for label in labels.split(",") if label.strip()]
  15. return labels
  16. def __call__(self, sequences, labels, hypothesis_template):
  17. if len(labels) == 0 or len(sequences) == 0:
  18. raise ValueError("You must include at least one label and at least one sequence.")
  19. if hypothesis_template.format(labels[0]) == hypothesis_template:
  20. raise ValueError(
  21. f'The provided hypothesis_template "{hypothesis_template}" was not able to be formatted with the target labels. '
  22. "Make sure the passed template includes formatting syntax such as {} where the label should go."
  23. )
  24. if isinstance(sequences, str):
  25. sequences = [sequences]
  26. sequence_pairs = []
  27. for sequence in sequences:
  28. sequence_pairs.extend([[sequence, hypothesis_template.format(label)] for label in labels])
  29. return sequence_pairs, sequences
  30. @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
  31. class ZeroShotClassificationPipeline(ChunkPipeline):
  32. """
  33. NLI-based zero-shot classification pipeline using a `ModelForSequenceClassification` trained on NLI (natural
  34. language inference) tasks. Equivalent of `text-classification` pipelines, but these models don't require a
  35. hardcoded number of potential classes, they can be chosen at runtime. It usually means it's slower but it is
  36. **much** more flexible.
  37. Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
  38. pair and passed to the pretrained model. Then, the logit for *entailment* is taken as the logit for the candidate
  39. label being valid. Any NLI model can be used, but the id of the *entailment* label must be included in the model
  40. config's :attr:*~transformers.PreTrainedConfig.label2id*.
  41. Example:
  42. ```python
  43. >>> from transformers import pipeline
  44. >>> oracle = pipeline(model="facebook/bart-large-mnli")
  45. >>> oracle(
  46. ... "I have a problem with my iphone that needs to be resolved asap!!",
  47. ... candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"],
  48. ... )
  49. {'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'], 'scores': [0.504, 0.479, 0.013, 0.003, 0.002]}
  50. >>> oracle(
  51. ... "I have a problem with my iphone that needs to be resolved asap!!",
  52. ... candidate_labels=["english", "german"],
  53. ... )
  54. {'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['english', 'german'], 'scores': [0.814, 0.186]}
  55. ```
  56. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  57. This NLI pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  58. `"zero-shot-classification"`.
  59. The models that this pipeline can use are models that have been fine-tuned on an NLI task. See the up-to-date list
  60. of available models on [huggingface.co/models](https://huggingface.co/models?search=nli).
  61. """
  62. _load_processor = False
  63. _load_image_processor = False
  64. _load_feature_extractor = False
  65. _load_tokenizer = True
  66. def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), **kwargs):
  67. self._args_parser = args_parser
  68. super().__init__(**kwargs)
  69. if self.entailment_id == -1:
  70. logger.warning(
  71. "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
  72. "-1. Define a descriptive label2id mapping in the model config to ensure correct outputs."
  73. )
  74. @property
  75. def entailment_id(self):
  76. for label, ind in self.model.config.label2id.items():
  77. if label.lower().startswith("entail"):
  78. return ind
  79. return -1
  80. def _parse_and_tokenize(
  81. self, sequence_pairs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.ONLY_FIRST, **kwargs
  82. ):
  83. """
  84. Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
  85. """
  86. return_tensors = "pt"
  87. if self.tokenizer.pad_token is None:
  88. # Override for tokenizers not supporting padding
  89. logger.error(
  90. "Tokenizer was not supporting padding necessary for zero-shot, attempting to use "
  91. " `pad_token=eos_token`"
  92. )
  93. self.tokenizer.pad_token = self.tokenizer.eos_token
  94. try:
  95. inputs = self.tokenizer(
  96. sequence_pairs,
  97. add_special_tokens=add_special_tokens,
  98. return_tensors=return_tensors,
  99. padding=padding,
  100. truncation=truncation,
  101. )
  102. except Exception as e:
  103. if "too short" in str(e):
  104. # tokenizers might yell that we want to truncate
  105. # to a value that is not even reached by the input.
  106. # In that case we don't want to truncate.
  107. # It seems there's not a really better way to catch that
  108. # exception.
  109. inputs = self.tokenizer(
  110. sequence_pairs,
  111. add_special_tokens=add_special_tokens,
  112. return_tensors=return_tensors,
  113. padding=padding,
  114. truncation=TruncationStrategy.DO_NOT_TRUNCATE,
  115. )
  116. else:
  117. raise e
  118. return inputs
  119. def _sanitize_parameters(self, **kwargs):
  120. preprocess_params = {}
  121. if "candidate_labels" in kwargs:
  122. preprocess_params["candidate_labels"] = self._args_parser._parse_labels(kwargs["candidate_labels"])
  123. if "hypothesis_template" in kwargs:
  124. preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
  125. postprocess_params = {}
  126. if "multi_label" in kwargs:
  127. postprocess_params["multi_label"] = kwargs["multi_label"]
  128. return preprocess_params, {}, postprocess_params
  129. def __call__(
  130. self,
  131. sequences: str | list[str],
  132. *args,
  133. **kwargs,
  134. ):
  135. """
  136. Classify the sequence(s) given as inputs. See the [`ZeroShotClassificationPipeline`] documentation for more
  137. information.
  138. Args:
  139. sequences (`str` or `list[str]`):
  140. The sequence(s) to classify, will be truncated if the model input is too large.
  141. candidate_labels (`str` or `list[str]`):
  142. The set of possible class labels to classify each sequence into. Can be a single label, a string of
  143. comma-separated labels, or a list of labels.
  144. hypothesis_template (`str`, *optional*, defaults to `"This example is {}."`):
  145. The template used to turn each label into an NLI-style hypothesis. This template must include a {} or
  146. similar syntax for the candidate label to be inserted into the template. For example, the default
  147. template is `"This example is {}."` With the candidate label `"sports"`, this would be fed into the
  148. model like `"<cls> sequence to classify <sep> This example is sports . <sep>"`. The default template
  149. works well in many cases, but it may be worthwhile to experiment with different templates depending on
  150. the task setting.
  151. multi_label (`bool`, *optional*, defaults to `False`):
  152. Whether or not multiple candidate labels can be true. If `False`, the scores are normalized such that
  153. the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
  154. independent and probabilities are normalized for each candidate by doing a softmax of the entailment
  155. score vs. the contradiction score.
  156. Return:
  157. A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
  158. - **sequence** (`str`) -- The sequence for which this is the output.
  159. - **labels** (`list[str]`) -- The labels sorted by order of likelihood.
  160. - **scores** (`list[float]`) -- The probabilities for each of the labels.
  161. """
  162. if len(args) == 0:
  163. pass
  164. elif len(args) == 1 and "candidate_labels" not in kwargs:
  165. kwargs["candidate_labels"] = args[0]
  166. else:
  167. raise ValueError(f"Unable to understand extra arguments {args}")
  168. return super().__call__(sequences, **kwargs)
  169. def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
  170. sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
  171. for i, (candidate_label, sequence_pair) in enumerate(zip(candidate_labels, sequence_pairs)):
  172. model_input = self._parse_and_tokenize([sequence_pair])
  173. yield {
  174. "candidate_label": candidate_label,
  175. "sequence": sequences[0],
  176. "is_last": i == len(candidate_labels) - 1,
  177. **model_input,
  178. }
  179. def _forward(self, inputs):
  180. candidate_label = inputs["candidate_label"]
  181. sequence = inputs["sequence"]
  182. model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
  183. # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
  184. model_forward = self.model.forward
  185. if "use_cache" in inspect.signature(model_forward).parameters:
  186. model_inputs["use_cache"] = False
  187. outputs = self.model(**model_inputs)
  188. model_outputs = {
  189. "candidate_label": candidate_label,
  190. "sequence": sequence,
  191. "is_last": inputs["is_last"],
  192. **outputs,
  193. }
  194. return model_outputs
  195. def postprocess(self, model_outputs, multi_label=False):
  196. candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
  197. sequences = [outputs["sequence"] for outputs in model_outputs]
  198. logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs])
  199. N = logits.shape[0]
  200. n = len(candidate_labels)
  201. num_sequences = N // n
  202. reshaped_outputs = logits.reshape((num_sequences, n, -1))
  203. if multi_label or len(candidate_labels) == 1:
  204. # softmax over the entailment vs. contradiction dim for each label independently
  205. entailment_id = self.entailment_id
  206. contradiction_id = -1 if entailment_id == 0 else 0
  207. entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
  208. scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
  209. scores = scores[..., 1]
  210. else:
  211. # softmax the "entailment" logits over all candidate labels
  212. entail_logits = reshaped_outputs[..., self.entailment_id]
  213. scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)
  214. top_inds = list(reversed(scores[0].argsort()))
  215. return {
  216. "sequence": sequences[0],
  217. "labels": [candidate_labels[i] for i in top_inds],
  218. "scores": scores[0, top_inds].tolist(),
  219. }