doc.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Doc utilities: Utilities related to documentation
  16. """
  17. import functools
  18. import inspect
  19. import re
  20. import textwrap
  21. import types
  22. from collections import OrderedDict
  23. from typing import cast
  24. def get_docstring_indentation_level(func):
  25. """Return the indentation level of the start of the docstring of a class or function (or method)."""
  26. # We assume classes are always defined in the global scope
  27. if inspect.isclass(func):
  28. return 4
  29. source = inspect.getsource(func)
  30. first_line = source.splitlines()[0]
  31. function_def_level = len(first_line) - len(first_line.lstrip())
  32. return 4 + function_def_level
  33. def add_start_docstrings(*docstr):
  34. def docstring_decorator(fn):
  35. fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
  36. return fn
  37. return docstring_decorator
def add_start_docstrings_to_model_forward(*docstr):
    """
    Build a decorator that prepends a generic forward-method introduction, followed by the `docstr` pieces, to the
    docstring of the decorated function.

    Args:
        *docstr: Strings inserted, in order, between the introduction and the function's existing docstring.

    Returns:
        A decorator that rewrites `fn.__doc__` in place and returns `fn`.
    """

    def docstring_decorator(fn):
        # The part of the qualified name before the first "." is shown as the owning class in the introduction.
        class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
        intro = rf""" The {class_name} forward method, overrides the `__call__` special method.
<Tip>
Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.
</Tip>
"""
        correct_indentation = get_docstring_indentation_level(fn)
        current_doc = fn.__doc__ if fn.__doc__ is not None else ""
        try:
            # Indentation actually used by the existing docstring, measured on its first non-blank line.
            first_non_empty = next(line for line in current_doc.splitlines() if line.strip() != "")
            doc_indentation = len(first_non_empty) - len(first_non_empty.lstrip())
        except StopIteration:
            # Missing or blank docstring: there is nothing to measure, so use the expected indentation.
            doc_indentation = correct_indentation
        docs = docstr
        # In this case, the correct indentation level (class method, 2 Python levels) was respected, and we should
        # correctly reindent everything. Otherwise, the doc uses a single indentation level
        if doc_indentation == 4 + correct_indentation:
            docs = [textwrap.indent(textwrap.dedent(doc), " " * correct_indentation) for doc in docstr]
            intro = textwrap.indent(textwrap.dedent(intro), " " * correct_indentation)
        # Final order: introduction, injected pieces, then the original docstring.
        docstring = "".join(docs) + current_doc
        fn.__doc__ = intro + docstring
        return fn

    return docstring_decorator
  65. def add_end_docstrings(*docstr):
  66. def docstring_decorator(fn):
  67. fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
  68. return fn
  69. return docstring_decorator
# Opening of the generated "Returns:" section. `_prepare_output_docstrings` fills `{full_output_type}` and
# `{config_class}` through `str.format`.
PT_RETURN_INTRODUCTION = r"""
Returns:
[`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
elements depending on the configuration ([`{config_class}`]) and inputs.
"""
  76. def _get_indent(t):
  77. """Returns the indentation in the first line of t"""
  78. search = re.search(r"^(\s*)\S", t)
  79. return "" if search is None else search.groups()[0]
  80. def _convert_output_args_doc(output_args_doc):
  81. """Convert output_args_doc to display properly."""
  82. # Split output_arg_doc in blocks argument/description
  83. indent = _get_indent(output_args_doc)
  84. blocks = []
  85. current_block = ""
  86. for line in output_args_doc.split("\n"):
  87. # If the indent is the same as the beginning, the line is the name of new arg.
  88. if _get_indent(line) == indent:
  89. if len(current_block) > 0:
  90. blocks.append(current_block[:-1])
  91. current_block = f"{line}\n"
  92. else:
  93. # Otherwise it's part of the description of the current arg.
  94. # We need to remove 2 spaces to the indentation.
  95. current_block += f"{line[2:]}\n"
  96. blocks.append(current_block[:-1])
  97. # Format each block for proper rendering
  98. for i in range(len(blocks)):
  99. blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
  100. blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
  101. return "\n".join(blocks)
def _prepare_output_docstrings(output_type, config_class, min_indent=None, add_intro=True):
    """
    Prepares the return part of the docstring using `output_type`.

    Args:
        output_type: Object whose `__doc__` is searched for an `Args:` or `Parameters:` header line. Its
            `__module__` and `__name__` are used when `add_intro` is true, `str(output_type)` otherwise.
        config_class: Value substituted for `{config_class}` in `PT_RETURN_INTRODUCTION` (used only when
            `add_intro` is true).
        min_indent: If not `None`, minimum indentation (in characters) of the first non-empty line of the result;
            every non-empty line is shifted right by the missing amount.
        add_intro: Whether to start with `PT_RETURN_INTRODUCTION` rather than a short `Returns:` header.

    Returns:
        `str`: The text of the "Returns" section.

    Raises:
        ValueError: If `add_intro` is true and the docstring of `output_type` exists but has no `Args:` or
            `Parameters:` line.
    """
    output_docstring = output_type.__doc__
    params_docstring = None
    if output_docstring is not None:
        # Remove the head of the docstring to keep the list of args only
        lines = output_docstring.split("\n")
        i = 0
        while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
            i += 1
        if i < len(lines):
            # Everything after the header line is the argument list.
            params_docstring = "\n".join(lines[(i + 1) :])
            params_docstring = _convert_output_args_doc(params_docstring)
        elif add_intro:
            raise ValueError(
                f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. Make sure it has "
                "docstring and contain either `Args` or `Parameters`."
            )
    # Add the return introduction
    if add_intro:
        full_output_type = f"{output_type.__module__}.{output_type.__name__}"
        intro = PT_RETURN_INTRODUCTION.format(full_output_type=full_output_type, config_class=config_class)
    else:
        full_output_type = str(output_type)
        intro = f"\nReturns:\n `{full_output_type}`"
        if params_docstring is not None:
            intro += ":\n"
    result = intro
    if params_docstring is not None:
        result += params_docstring
    # Apply minimum indent if necessary
    if min_indent is not None:
        lines = result.split("\n")
        # Find the indent of the first nonempty line
        i = 0
        while len(lines[i]) == 0:
            i += 1
        indent = len(_get_indent(lines[i]))
        # If too small, add indentation to all nonempty lines
        if indent < min_indent:
            to_add = " " * (min_indent - indent)
            lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
            result = "\n".join(lines)
    return result
# Warning tip with `{real_checkpoint}` and `{fake_checkpoint}` placeholders.
# NOTE(review): `{true}` in the `<Tip>` tag is also a `str.format` field; presumably the caller supplies a `true`
# key when formatting this text - confirm against the call site.
FAKE_MODEL_DISCLAIMER = """
<Tip warning={true}>
This example uses a random model as the real ones are all very big. To get proper results, you should use
{real_checkpoint} instead of {fake_checkpoint}. If you get out-of-memory when loading that checkpoint, you can try
adding `device_map="auto"` in the `from_pretrained` call.
</Tip>
"""
  155. PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
  156. Example:
  157. ```python
  158. >>> from transformers import AutoTokenizer, {model_class}
  159. >>> import torch
  160. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  161. >>> model = {model_class}.from_pretrained("{checkpoint}")
  162. >>> inputs = tokenizer(
  163. ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
  164. ... )
  165. >>> with torch.no_grad():
  166. ... logits = model(**inputs).logits
  167. >>> predicted_token_class_ids = logits.argmax(-1)
  168. >>> # Note that tokens are classified rather then input words which means that
  169. >>> # there might be more predicted token classes than words.
  170. >>> # Multiple token classes might account for the same word
  171. >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
  172. >>> predicted_tokens_classes
  173. {expected_output}
  174. >>> labels = predicted_token_class_ids
  175. >>> loss = model(**inputs, labels=labels).loss
  176. >>> round(loss.item(), 2)
  177. {expected_loss}
  178. ```
  179. """
  180. PT_QUESTION_ANSWERING_SAMPLE = r"""
  181. Example:
  182. ```python
  183. >>> from transformers import AutoTokenizer, {model_class}
  184. >>> import torch
  185. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  186. >>> model = {model_class}.from_pretrained("{checkpoint}")
  187. >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
  188. >>> inputs = tokenizer(question, text, return_tensors="pt")
  189. >>> with torch.no_grad():
  190. ... outputs = model(**inputs)
  191. >>> answer_start_index = outputs.start_logits.argmax()
  192. >>> answer_end_index = outputs.end_logits.argmax()
  193. >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
  194. >>> tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
  195. {expected_output}
  196. >>> # target is "nice puppet"
  197. >>> target_start_index = torch.tensor([{qa_target_start_index}])
  198. >>> target_end_index = torch.tensor([{qa_target_end_index}])
  199. >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
  200. >>> loss = outputs.loss
  201. >>> round(loss.item(), 2)
  202. {expected_loss}
  203. ```
  204. """
  205. PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
  206. Example of single-label classification:
  207. ```python
  208. >>> import torch
  209. >>> from transformers import AutoTokenizer, {model_class}
  210. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  211. >>> model = {model_class}.from_pretrained("{checkpoint}")
  212. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  213. >>> with torch.no_grad():
  214. ... logits = model(**inputs).logits
  215. >>> predicted_class_id = logits.argmax().item()
  216. >>> model.config.id2label[predicted_class_id]
  217. {expected_output}
  218. >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
  219. >>> num_labels = len(model.config.id2label)
  220. >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
  221. >>> labels = torch.tensor([1])
  222. >>> loss = model(**inputs, labels=labels).loss
  223. >>> round(loss.item(), 2)
  224. {expected_loss}
  225. ```
  226. Example of multi-label classification:
  227. ```python
  228. >>> import torch
  229. >>> from transformers import AutoTokenizer, {model_class}
  230. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  231. >>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification")
  232. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  233. >>> with torch.no_grad():
  234. ... logits = model(**inputs).logits
  235. >>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]
  236. >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
  237. >>> num_labels = len(model.config.id2label)
  238. >>> model = {model_class}.from_pretrained(
  239. ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
  240. ... )
  241. >>> labels = torch.sum(
  242. ... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
  243. ... ).to(torch.float)
  244. >>> loss = model(**inputs, labels=labels).loss
  245. ```
  246. """
  247. PT_MASKED_LM_SAMPLE = r"""
  248. Example:
  249. ```python
  250. >>> from transformers import AutoTokenizer, {model_class}
  251. >>> import torch
  252. >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
  253. >>> model = {model_class}.from_pretrained("{checkpoint}")
  254. >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
  255. >>> with torch.no_grad():
  256. ... logits = model(**inputs).logits
  257. >>> # retrieve index of {mask}
  258. >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
  259. >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
  260. >>> tokenizer.decode(predicted_token_id)
  261. {expected_output}
  262. >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
  263. >>> # mask labels of non-{mask} tokens
  264. >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
  265. >>> outputs = model(**inputs, labels=labels)
  266. >>> round(outputs.loss.item(), 2)
  267. {expected_loss}
  268. ```
  269. """
# Doctest template for a base-model forward pass. Placeholders: `{model_class}` and `{checkpoint}`.
PT_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Doctest template for a multiple-choice example. Placeholders: `{model_class}` and `{checkpoint}`; the doubled
# braces are `str.format` escapes for the literal dict comprehension.
PT_MULTIPLE_CHOICE_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoTokenizer, {model_class}
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
>>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> loss = outputs.loss
>>> logits = outputs.logits
```
"""
# Doctest template for a causal-language-modeling example. Placeholders: `{model_class}` and `{checkpoint}`.
PT_CAUSAL_LM_SAMPLE = r"""
Example:
```python
>>> import torch
>>> from transformers import AutoTokenizer, {model_class}
>>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss = outputs.loss
>>> logits = outputs.logits
```
"""
  313. PT_SPEECH_BASE_MODEL_SAMPLE = r"""
  314. Example:
  315. ```python
  316. >>> from transformers import AutoProcessor, {model_class}
  317. >>> import torch
  318. >>> from datasets import load_dataset
  319. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  320. >>> dataset = dataset.sort("id")
  321. >>> sampling_rate = dataset.features["audio"].sampling_rate
  322. >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
  323. >>> model = {model_class}.from_pretrained("{checkpoint}")
  324. >>> # audio file is decoded on the fly
  325. >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  326. >>> with torch.no_grad():
  327. ... outputs = model(**inputs)
  328. >>> last_hidden_states = outputs.last_hidden_state
  329. >>> list(last_hidden_states.shape)
  330. {expected_output}
  331. ```
  332. """
  333. PT_SPEECH_CTC_SAMPLE = r"""
  334. Example:
  335. ```python
  336. >>> from transformers import AutoProcessor, {model_class}
  337. >>> from datasets import load_dataset
  338. >>> import torch
  339. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  340. >>> dataset = dataset.sort("id")
  341. >>> sampling_rate = dataset.features["audio"].sampling_rate
  342. >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
  343. >>> model = {model_class}.from_pretrained("{checkpoint}")
  344. >>> # audio file is decoded on the fly
  345. >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  346. >>> with torch.no_grad():
  347. ... logits = model(**inputs).logits
  348. >>> predicted_ids = torch.argmax(logits, dim=-1)
  349. >>> # transcribe speech
  350. >>> transcription = processor.batch_decode(predicted_ids)
  351. >>> transcription[0]
  352. {expected_output}
  353. >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
  354. >>> # compute loss
  355. >>> loss = model(**inputs).loss
  356. >>> round(loss.item(), 2)
  357. {expected_loss}
  358. ```
  359. """
  360. PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
  361. Example:
  362. ```python
  363. >>> from transformers import AutoFeatureExtractor, {model_class}
  364. >>> from datasets import load_dataset
  365. >>> import torch
  366. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  367. >>> dataset = dataset.sort("id")
  368. >>> sampling_rate = dataset.features["audio"].sampling_rate
  369. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
  370. >>> model = {model_class}.from_pretrained("{checkpoint}")
  371. >>> # audio file is decoded on the fly
  372. >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  373. >>> with torch.no_grad():
  374. ... logits = model(**inputs).logits
  375. >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
  376. >>> predicted_label = model.config.id2label[predicted_class_ids]
  377. >>> predicted_label
  378. {expected_output}
  379. >>> # compute loss - target_label is e.g. "down"
  380. >>> target_label = model.config.id2label[0]
  381. >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
  382. >>> loss = model(**inputs).loss
  383. >>> round(loss.item(), 2)
  384. {expected_loss}
  385. ```
  386. """
  387. PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
  388. Example:
  389. ```python
  390. >>> from transformers import AutoFeatureExtractor, {model_class}
  391. >>> from datasets import load_dataset
  392. >>> import torch
  393. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  394. >>> dataset = dataset.sort("id")
  395. >>> sampling_rate = dataset.features["audio"].sampling_rate
  396. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
  397. >>> model = {model_class}.from_pretrained("{checkpoint}")
  398. >>> # audio file is decoded on the fly
  399. >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
  400. >>> with torch.no_grad():
  401. ... logits = model(**inputs).logits
  402. >>> probabilities = torch.sigmoid(logits[0])
  403. >>> # labels is a one-hot array of shape (num_frames, num_speakers)
  404. >>> labels = (probabilities > 0.5).long()
  405. >>> labels[0].tolist()
  406. {expected_output}
  407. ```
  408. """
  409. PT_SPEECH_XVECTOR_SAMPLE = r"""
  410. Example:
  411. ```python
  412. >>> from transformers import AutoFeatureExtractor, {model_class}
  413. >>> from datasets import load_dataset
  414. >>> import torch
  415. >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  416. >>> dataset = dataset.sort("id")
  417. >>> sampling_rate = dataset.features["audio"].sampling_rate
  418. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
  419. >>> model = {model_class}.from_pretrained("{checkpoint}")
  420. >>> # audio file is decoded on the fly
  421. >>> inputs = feature_extractor(
  422. ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
  423. ... )
  424. >>> with torch.no_grad():
  425. ... embeddings = model(**inputs).embeddings
  426. >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
  427. >>> # the resulting embeddings can be used for cosine similarity-based retrieval
  428. >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
  429. >>> similarity = cosine_sim(embeddings[0], embeddings[1])
  430. >>> threshold = 0.7 # the optimal threshold is dataset-dependent
  431. >>> if similarity < threshold:
  432. ... print("Speakers are not the same!")
  433. >>> round(similarity.item(), 2)
  434. {expected_output}
  435. ```
  436. """
  437. PT_VISION_BASE_MODEL_SAMPLE = r"""
  438. Example:
  439. ```python
  440. >>> from transformers import AutoImageProcessor, {model_class}
  441. >>> import torch
  442. >>> from datasets import load_dataset
  443. >>> dataset = load_dataset("huggingface/cats-image")
  444. >>> image = dataset["test"]["image"][0]
  445. >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
  446. >>> model = {model_class}.from_pretrained("{checkpoint}")
  447. >>> inputs = image_processor(image, return_tensors="pt")
  448. >>> with torch.no_grad():
  449. ... outputs = model(**inputs)
  450. >>> last_hidden_states = outputs.last_hidden_state
  451. >>> list(last_hidden_states.shape)
  452. {expected_output}
  453. ```
  454. """
  455. PT_VISION_SEQ_CLASS_SAMPLE = r"""
  456. Example:
  457. ```python
  458. >>> from transformers import AutoImageProcessor, {model_class}
  459. >>> import torch
  460. >>> from datasets import load_dataset
  461. >>> dataset = load_dataset("huggingface/cats-image")
  462. >>> image = dataset["test"]["image"][0]
  463. >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
  464. >>> model = {model_class}.from_pretrained("{checkpoint}")
  465. >>> inputs = image_processor(image, return_tensors="pt")
  466. >>> with torch.no_grad():
  467. ... logits = model(**inputs).logits
  468. >>> # model predicts one of the 1000 ImageNet classes
  469. >>> predicted_label = logits.argmax(-1).item()
  470. >>> print(model.config.id2label[predicted_label])
  471. {expected_output}
  472. ```
  473. """
# Lookup table from a task/head key to the PyTorch sample template defined above for it.
PT_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": PT_MASKED_LM_SAMPLE,
    "LMHead": PT_CAUSAL_LM_SAMPLE,
    "BaseModel": PT_BASE_MODEL_SAMPLE,
    "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
    "CTC": PT_SPEECH_CTC_SAMPLE,
    "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
    "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
    "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
    "VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
    "ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
}
# Doctest template for text-to-speech through a spectrogram model and a vocoder. Placeholders: `{model_class}`
# and `{checkpoint}`.
# NOTE(review): the last example line passes `speaker_embeddings=speaker_embeddings`, but no earlier line of the
# template defines `speaker_embeddings`, so the example cannot run as written.
TEXT_TO_AUDIO_SPECTROGRAM_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoProcessor, {model_class}, SpeechT5HifiGan
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> processor = AutoProcessor.from_pretrained("{checkpoint}")
>>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
>>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
>>> # generate speech
>>> speech = model.generate(inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder)
```
"""
# Doctest template for text-to-speech with a model that is called directly on the input ids. Placeholders:
# `{model_class}` and `{checkpoint}`.
TEXT_TO_AUDIO_WAVEFORM_SAMPLE = r"""
Example:
```python
>>> from transformers import AutoProcessor, {model_class}
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> processor = AutoProcessor.from_pretrained("{checkpoint}")
>>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
>>> # generate speech
>>> speech = model(inputs["input_ids"])
```
"""
# Task-named aliases of the speech templates defined earlier in the file.
AUDIO_FRAME_CLASSIFICATION_SAMPLE = PT_SPEECH_FRAME_CLASS_SAMPLE
AUDIO_XVECTOR_SAMPLE = PT_SPEECH_XVECTOR_SAMPLE
  515. DEPTH_ESTIMATION_SAMPLE = r"""
  516. Example:
  517. ```python
  518. >>> from transformers import AutoImageProcessor, {model_class}
  519. >>> import torch
  520. >>> from PIL import Image
  521. >>> import httpx
  522. >>> from io import BytesIO
  523. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  524. >>> with httpx.stream("GET", url) as response:
  525. ... image = Image.open(BytesIO(response.read())).convert("RGB")
  526. >>> processor = AutoImageProcessor.from_pretrained("{checkpoint}")
  527. >>> model = {model_class}.from_pretrained("{checkpoint}")
  528. >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  529. >>> model.to(device)
  530. >>> # prepare image for the model
  531. >>> inputs = processor(images=image, return_tensors="pt").to(device)
  532. >>> with torch.no_grad():
  533. ... outputs = model(**inputs)
  534. >>> # interpolate to original size
  535. >>> post_processed_output = processor.post_process_depth_estimation(
  536. ... outputs, [(image.height, image.width)],
  537. ... )
  538. >>> predicted_depth = post_processed_output[0]["predicted_depth"]
  539. ```
  540. """
# Task-named sample templates. Those whose code fence is empty are placeholders with no example code; the plain
# assignments are aliases of templates defined earlier in the file.
VIDEO_CLASSIFICATION_SAMPLE = r"""
Example:
```python
```
"""
ZERO_SHOT_OBJECT_DETECTION_SAMPLE = r"""
Example:
```python
```
"""
IMAGE_TO_IMAGE_SAMPLE = r"""
Example:
```python
```
"""
IMAGE_FEATURE_EXTRACTION_SAMPLE = r"""
Example:
```python
```
"""
DOCUMENT_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
```
"""
NEXT_SENTENCE_PREDICTION_SAMPLE = r"""
Example:
```python
```
"""
# Alias of the multiple-choice template.
MULTIPLE_CHOICE_SAMPLE = PT_MULTIPLE_CHOICE_SAMPLE
PRETRAINING_SAMPLE = r"""
Example:
```python
```
"""
MASK_GENERATION_SAMPLE = r"""
Example:
```python
```
"""
VISUAL_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
```
"""
TEXT_GENERATION_SAMPLE = r"""
Example:
```python
```
"""
# Alias of the image-classification template.
IMAGE_CLASSIFICATION_SAMPLE = PT_VISION_SEQ_CLASS_SAMPLE
IMAGE_SEGMENTATION_SAMPLE = r"""
Example:
```python
```
"""
FILL_MASK_SAMPLE = r"""
Example:
```python
```
"""
OBJECT_DETECTION_SAMPLE = r"""
Example:
```python
```
"""
# Aliases of the question-answering and sequence-classification templates.
QUESTION_ANSWERING_SAMPLE = PT_QUESTION_ANSWERING_SAMPLE
TEXT_CLASSIFICATION_SAMPLE = PT_SEQUENCE_CLASSIFICATION_SAMPLE
TABLE_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
```
"""
# Aliases of the token-classification and speech templates.
TOKEN_CLASSIFICATION_SAMPLE = PT_TOKEN_CLASSIFICATION_SAMPLE
AUDIO_CLASSIFICATION_SAMPLE = PT_SPEECH_SEQ_CLASS_SAMPLE
AUTOMATIC_SPEECH_RECOGNITION_SAMPLE = PT_SPEECH_CTC_SAMPLE
ZERO_SHOT_IMAGE_CLASSIFICATION_SAMPLE = r"""
Example:
```python
```
"""
# Doctest template for an image-text-to-text generation example. Placeholders: `{model_class}` and `{checkpoint}`;
# the doubled braces are `str.format` escapes for the literal dicts of the chat message.
IMAGE_TEXT_TO_TEXT_GENERATION_SAMPLE = r"""
Example:
```python
>>> from PIL import Image
>>> from transformers import AutoProcessor, {model_class}
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> processor = AutoProcessor.from_pretrained("{checkpoint}")
>>> messages = [
... {{
... "role": "user", "content": [
... {{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"}},
... {{"type": "text", "text": "Where is the cat standing?"}},
... ]
... }},
... ]
>>> inputs = processor.apply_chat_template(
... messages,
... tokenize=True,
... return_dict=True,
... return_tensors="pt",
... add_generation_prompt=True
... )
>>> # Generate
>>> generate_ids = model.generate(**inputs)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
```
"""
# Ordered mapping from a task identifier string to the sample template used for that task.
PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS = OrderedDict(
    [
        ("text-to-audio-spectrogram", TEXT_TO_AUDIO_SPECTROGRAM_SAMPLE),
        ("text-to-audio-waveform", TEXT_TO_AUDIO_WAVEFORM_SAMPLE),
        ("automatic-speech-recognition", AUTOMATIC_SPEECH_RECOGNITION_SAMPLE),
        ("audio-frame-classification", AUDIO_FRAME_CLASSIFICATION_SAMPLE),
        ("audio-classification", AUDIO_CLASSIFICATION_SAMPLE),
        ("audio-xvector", AUDIO_XVECTOR_SAMPLE),
        ("image-text-to-text", IMAGE_TEXT_TO_TEXT_GENERATION_SAMPLE),
        ("depth-estimation", DEPTH_ESTIMATION_SAMPLE),
        ("video-classification", VIDEO_CLASSIFICATION_SAMPLE),
        ("zero-shot-image-classification", ZERO_SHOT_IMAGE_CLASSIFICATION_SAMPLE),
        ("image-classification", IMAGE_CLASSIFICATION_SAMPLE),
        ("zero-shot-object-detection", ZERO_SHOT_OBJECT_DETECTION_SAMPLE),
        ("object-detection", OBJECT_DETECTION_SAMPLE),
        ("image-segmentation", IMAGE_SEGMENTATION_SAMPLE),
        ("image-feature-extraction", IMAGE_FEATURE_EXTRACTION_SAMPLE),
        ("text-generation", TEXT_GENERATION_SAMPLE),
        ("table-question-answering", TABLE_QUESTION_ANSWERING_SAMPLE),
        ("document-question-answering", DOCUMENT_QUESTION_ANSWERING_SAMPLE),
        ("next-sentence-prediction", NEXT_SENTENCE_PREDICTION_SAMPLE),
        ("multiple-choice", MULTIPLE_CHOICE_SAMPLE),
        ("text-classification", TEXT_CLASSIFICATION_SAMPLE),
        ("token-classification", TOKEN_CLASSIFICATION_SAMPLE),
        ("fill-mask", FILL_MASK_SAMPLE),
        ("mask-generation", MASK_GENERATION_SAMPLE),
        ("pretraining", PRETRAINING_SAMPLE),
    ]
)
# Maps the name of a model-mapping constant to the pipeline task whose sample docstring it should use.
# Insertion order is significant: the more specialized model mappings are listed (and therefore looked up)
# first, before falling back to the more generic ones, so do not reorder entries casually.
# Note that two mappings (speech seq2seq and CTC) share the "automatic-speech-recognition" task.
MODELS_TO_PIPELINE = OrderedDict(
    [
        # Audio
        ("MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES", "text-to-audio-spectrogram"),
        ("MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "text-to-audio-waveform"),
        ("MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "automatic-speech-recognition"),
        ("MODEL_FOR_CTC_MAPPING_NAMES", "automatic-speech-recognition"),
        ("MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES", "audio-frame-classification"),
        ("MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES", "audio-classification"),
        ("MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "audio-xvector"),
        # Vision
        ("MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "image-text-to-text"),
        ("MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "depth-estimation"),
        ("MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "video-classification"),
        ("MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES", "zero-shot-image-classification"),
        ("MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "image-classification"),
        ("MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES", "zero-shot-object-detection"),
        ("MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "object-detection"),
        ("MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "image-segmentation"),
        ("MODEL_FOR_IMAGE_MAPPING_NAMES", "image-feature-extraction"),
        # Text/tokens
        ("MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", "text-generation"),
        ("MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES", "table-question-answering"),
        ("MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES", "document-question-answering"),
        ("MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES", "next-sentence-prediction"),
        ("MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES", "multiple-choice"),
        ("MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "text-classification"),
        ("MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES", "token-classification"),
        ("MODEL_FOR_MASKED_LM_MAPPING_NAMES", "fill-mask"),
        ("MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "mask-generation"),
        ("MODEL_FOR_PRETRAINING_MAPPING_NAMES", "pretraining"),
    ]
)
  714. def filter_outputs_from_example(docstring, **kwargs):
  715. """
  716. Removes the lines testing an output with the doctest syntax in a code sample when it's set to `None`.
  717. """
  718. for key, value in kwargs.items():
  719. if value is not None:
  720. continue
  721. doc_key = "{" + key + "}"
  722. docstring = re.sub(rf"\n([^\n]+)\n\s+{doc_key}\n", "\n", docstring)
  723. return docstring
  724. def add_code_sample_docstrings(
  725. *docstr,
  726. processor_class=None,
  727. checkpoint=None,
  728. output_type=None,
  729. config_class=None,
  730. mask="[MASK]",
  731. qa_target_start_index=14,
  732. qa_target_end_index=15,
  733. model_cls=None,
  734. modality=None,
  735. expected_output=None,
  736. expected_loss=None,
  737. real_checkpoint=None,
  738. revision=None,
  739. ):
  740. def docstring_decorator(fn):
  741. # model_class defaults to function's class if not specified otherwise
  742. model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
  743. sample_docstrings = PT_SAMPLE_DOCSTRINGS
  744. # putting all kwargs for docstrings in a dict to be used
  745. # with the `.format(**doc_kwargs)`. Note that string might
  746. # be formatted with non-existing keys, which is fine.
  747. doc_kwargs = {
  748. "model_class": model_class,
  749. "processor_class": processor_class,
  750. "checkpoint": checkpoint,
  751. "mask": mask,
  752. "qa_target_start_index": qa_target_start_index,
  753. "qa_target_end_index": qa_target_end_index,
  754. "expected_output": expected_output,
  755. "expected_loss": expected_loss,
  756. "real_checkpoint": real_checkpoint,
  757. "fake_checkpoint": checkpoint,
  758. "true": "{true}", # For <Tip warning={true}> syntax that conflicts with formatting.
  759. }
  760. if ("SequenceClassification" in model_class or "AudioClassification" in model_class) and modality == "audio":
  761. code_sample = sample_docstrings["AudioClassification"]
  762. elif "SequenceClassification" in model_class:
  763. code_sample = sample_docstrings["SequenceClassification"]
  764. elif "QuestionAnswering" in model_class:
  765. code_sample = sample_docstrings["QuestionAnswering"]
  766. elif "TokenClassification" in model_class:
  767. code_sample = sample_docstrings["TokenClassification"]
  768. elif "MultipleChoice" in model_class:
  769. code_sample = sample_docstrings["MultipleChoice"]
  770. elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
  771. code_sample = sample_docstrings["MaskedLM"]
  772. elif "LMHead" in model_class or "CausalLM" in model_class:
  773. code_sample = sample_docstrings["LMHead"]
  774. elif "CTC" in model_class:
  775. code_sample = sample_docstrings["CTC"]
  776. elif "AudioFrameClassification" in model_class:
  777. code_sample = sample_docstrings["AudioFrameClassification"]
  778. elif "XVector" in model_class and modality == "audio":
  779. code_sample = sample_docstrings["AudioXVector"]
  780. elif "Model" in model_class and modality == "audio":
  781. code_sample = sample_docstrings["SpeechBaseModel"]
  782. elif "Model" in model_class and modality == "vision":
  783. code_sample = sample_docstrings["VisionBaseModel"]
  784. elif "Model" in model_class or "Encoder" in model_class:
  785. code_sample = sample_docstrings["BaseModel"]
  786. elif "ImageClassification" in model_class:
  787. code_sample = sample_docstrings["ImageClassification"]
  788. else:
  789. raise ValueError(f"Docstring can't be built for model {model_class}")
  790. code_sample = filter_outputs_from_example(
  791. code_sample, expected_output=expected_output, expected_loss=expected_loss
  792. )
  793. if real_checkpoint is not None:
  794. code_sample = FAKE_MODEL_DISCLAIMER + code_sample
  795. func_doc = (fn.__doc__ or "") + "".join(docstr)
  796. output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
  797. built_doc = code_sample.format(**doc_kwargs)
  798. if revision is not None:
  799. if re.match(r"^refs/pr/\\d+", revision):
  800. raise ValueError(
  801. f"The provided revision '{revision}' is incorrect. It should point to"
  802. " a pull request reference on the hub like 'refs/pr/6'"
  803. )
  804. built_doc = built_doc.replace(
  805. f'from_pretrained("{checkpoint}")', f'from_pretrained("{checkpoint}", revision="{revision}")'
  806. )
  807. fn.__doc__ = func_doc + output_doc + built_doc
  808. return fn
  809. return docstring_decorator
  810. def replace_return_docstrings(output_type=None, config_class=None):
  811. def docstring_decorator(fn):
  812. func_doc = fn.__doc__
  813. lines = func_doc.split("\n")
  814. i = 0
  815. while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
  816. i += 1
  817. if i < len(lines):
  818. indent = len(_get_indent(lines[i]))
  819. lines[i] = _prepare_output_docstrings(output_type, config_class, min_indent=indent)
  820. func_doc = "\n".join(lines)
  821. else:
  822. raise ValueError(
  823. f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, "
  824. f"current docstring is:\n{func_doc}"
  825. )
  826. fn.__doc__ = func_doc
  827. return fn
  828. return docstring_decorator
  829. def copy_func(f):
  830. """Returns a copy of a function f."""
  831. # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
  832. g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
  833. g = cast(types.FunctionType, functools.update_wrapper(g, f))
  834. g.__kwdefaults__ = f.__kwdefaults__
  835. return g