__init__.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. #!/usr/bin/env python
  2. # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from functools import lru_cache
  16. from packaging import version
  17. from .. import __version__
  18. from .auto_docstring import (
  19. ClassAttrs,
  20. ClassDocstring,
  21. ImageProcessorArgs,
  22. ModelArgs,
  23. ModelOutputArgs,
  24. auto_class_docstring,
  25. auto_docstring,
  26. get_args_doc_from_source,
  27. parse_docstring,
  28. set_min_indent,
  29. )
  30. from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema
  31. from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
  32. from .doc import (
  33. add_code_sample_docstrings,
  34. add_end_docstrings,
  35. add_start_docstrings,
  36. add_start_docstrings_to_model_forward,
  37. copy_func,
  38. replace_return_docstrings,
  39. )
  40. from .generic import (
  41. ContextManagers,
  42. ExplicitEnum,
  43. ModelOutput,
  44. PaddingStrategy,
  45. TensorType,
  46. TransformersKwargs,
  47. _is_tensor_or_array_like,
  48. can_return_loss,
  49. can_return_tuple,
  50. expand_dims,
  51. filter_out_non_signature_kwargs,
  52. find_labels,
  53. flatten_dict,
  54. is_numpy_array,
  55. is_tensor,
  56. is_timm_config_dict,
  57. is_timm_local_checkpoint,
  58. is_torch_device,
  59. is_torch_dtype,
  60. is_torch_tensor,
  61. reshape,
  62. safe_load_json_file,
  63. squeeze,
  64. strtobool,
  65. tensor_size,
  66. to_numpy,
  67. to_py_obj,
  68. torch_float,
  69. torch_int,
  70. transpose,
  71. )
  72. from .hub import (
  73. CHAT_TEMPLATE_DIR,
  74. CHAT_TEMPLATE_FILE,
  75. CLOUDFRONT_DISTRIB_PREFIX,
  76. HF_MODULES_CACHE,
  77. LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
  78. S3_BUCKET_PREFIX,
  79. TRANSFORMERS_DYNAMIC_MODULE_NAME,
  80. EntryNotFoundError,
  81. PushInProgress,
  82. PushToHubMixin,
  83. RepositoryNotFoundError,
  84. RevisionNotFoundError,
  85. cached_file,
  86. define_sagemaker_information,
  87. extract_commit_hash,
  88. has_file,
  89. http_user_agent,
  90. list_repo_templates,
  91. try_to_load_from_cache,
  92. )
  93. from .import_utils import (
  94. ACCELERATE_MIN_VERSION,
  95. BITSANDBYTES_MIN_VERSION,
  96. ENV_VARS_TRUE_AND_AUTO_VALUES,
  97. ENV_VARS_TRUE_VALUES,
  98. GGUF_MIN_VERSION,
  99. TRITON_MIN_VERSION,
  100. XLA_FSDPV2_MIN_VERSION,
  101. DummyObject,
  102. OptionalDependencyNotAvailable,
  103. _LazyModule,
  104. check_torch_load_is_safe,
  105. direct_transformers_import,
  106. enable_tf32,
  107. get_torch_version,
  108. is_accelerate_available,
  109. is_apex_available,
  110. is_apollo_torch_available,
  111. is_aqlm_available,
  112. is_auto_round_available,
  113. is_av_available,
  114. is_bitsandbytes_available,
  115. is_bs4_available,
  116. is_coloredlogs_available,
  117. is_compressed_tensors_available,
  118. is_cuda_platform,
  119. is_cv2_available,
  120. is_cython_available,
  121. is_datasets_available,
  122. is_decord_available,
  123. is_detectron2_available,
  124. is_env_variable_false,
  125. is_env_variable_true,
  126. is_essentia_available,
  127. is_faiss_available,
  128. is_fbgemm_gpu_available,
  129. is_flash_attn_2_available,
  130. is_flash_attn_3_available,
  131. is_flash_attn_4_available,
  132. is_flash_attn_greater_or_equal,
  133. is_flash_attn_greater_or_equal_2_10,
  134. is_flute_available,
  135. is_fouroversix_available,
  136. is_fp_quant_available,
  137. is_fsdp_available,
  138. is_g2p_en_available,
  139. is_galore_torch_available,
  140. is_gguf_available,
  141. is_gptqmodel_available,
  142. is_grokadamw_available,
  143. is_grouped_mm_available,
  144. is_habana_gaudi1,
  145. is_hadamard_available,
  146. is_hqq_available,
  147. is_huggingface_hub_greater_or_equal,
  148. is_in_notebook,
  149. is_jinja_available,
  150. is_jmespath_available,
  151. is_jumanpp_available,
  152. is_kenlm_available,
  153. is_kernels_available,
  154. is_levenshtein_available,
  155. is_libcst_available,
  156. is_librosa_available,
  157. is_liger_kernel_available,
  158. is_llm_awq_available,
  159. is_lomo_available,
  160. is_matplotlib_available,
  161. is_mistral_common_available,
  162. is_mlx_available,
  163. is_multipart_available,
  164. is_natten_available,
  165. is_ninja_available,
  166. is_nltk_available,
  167. is_num2words_available,
  168. is_numba_available,
  169. is_onnx_available,
  170. is_openai_available,
  171. is_optimum_available,
  172. is_optimum_quanto_available,
  173. is_pandas_available,
  174. is_peft_available,
  175. is_phonemizer_available,
  176. is_pretty_midi_available,
  177. is_protobuf_available,
  178. is_psutil_available,
  179. is_py3nvml_available,
  180. is_pyctcdecode_available,
  181. is_pytesseract_available,
  182. is_pytest_available,
  183. is_pytest_order_available,
  184. is_pytorch_quantization_available,
  185. is_quanto_greater,
  186. is_quark_available,
  187. is_qutlass_available,
  188. is_rich_available,
  189. is_rjieba_available,
  190. is_rocm_platform,
  191. is_sacremoses_available,
  192. is_sagemaker_dp_enabled,
  193. is_sagemaker_mp_enabled,
  194. is_schedulefree_available,
  195. is_scipy_available,
  196. is_sentencepiece_available,
  197. is_seqio_available,
  198. is_serve_available,
  199. is_sinq_available,
  200. is_sklearn_available,
  201. is_soundfile_available,
  202. is_spacy_available,
  203. is_speech_available,
  204. is_spqr_available,
  205. is_sudachi_available,
  206. is_sudachi_projection_available,
  207. is_tiktoken_available,
  208. is_timm_available,
  209. is_tokenizers_available,
  210. is_torch_accelerator_available,
  211. is_torch_available,
  212. is_torch_bf16_available_on_device,
  213. is_torch_bf16_gpu_available,
  214. is_torch_cuda_available,
  215. is_torch_deterministic,
  216. is_torch_flex_attn_available,
  217. is_torch_fp16_available_on_device,
  218. is_torch_fx_proxy,
  219. is_torch_greater_or_equal,
  220. is_torch_hpu_available,
  221. is_torch_mlu_available,
  222. is_torch_mps_available,
  223. is_torch_musa_available,
  224. is_torch_neuron_available,
  225. is_torch_neuroncore_available,
  226. is_torch_npu_available,
  227. is_torch_optimi_available,
  228. is_torch_tensorrt_fx_available,
  229. is_torch_tf32_available,
  230. is_torch_xla_available,
  231. is_torch_xpu_available,
  232. is_torchao_available,
  233. is_torchaudio_available,
  234. is_torchcodec_available,
  235. is_torchdistx_available,
  236. is_torchdynamo_compiling,
  237. is_torchdynamo_exporting,
  238. is_torchvision_available,
  239. is_torchvision_v2_available,
  240. is_tracing,
  241. is_training_run_on_sagemaker,
  242. is_triton_available,
  243. is_uroman_available,
  244. is_vision_available,
  245. is_vptq_available,
  246. is_xlstm_available,
  247. is_yt_dlp_available,
  248. requires_backends,
  249. torch_compilable_check,
  250. torch_only_method,
  251. )
  252. from .kernel_config import KernelConfig
  253. from .peft_utils import (
  254. ADAPTER_CONFIG_NAME,
  255. ADAPTER_SAFE_WEIGHTS_NAME,
  256. ADAPTER_WEIGHTS_NAME,
  257. check_peft_version,
  258. find_adapter_config_file,
  259. )
  260. WEIGHTS_NAME = "pytorch_model.bin"
  261. WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
  262. SAFE_WEIGHTS_NAME = "model.safetensors"
  263. SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
  264. CONFIG_NAME = "config.json"
  265. FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
  266. IMAGE_PROCESSOR_NAME = "preprocessor_config.json"
  267. VIDEO_PROCESSOR_NAME = "video_preprocessor_config.json"
  268. AUDIO_TOKENIZER_NAME = "audio_tokenizer_config.json"
  269. PROCESSOR_NAME = "processor_config.json"
  270. GENERATION_CONFIG_NAME = "generation_config.json"
  271. MODEL_CARD_NAME = "modelcard.json"
  272. SENTENCEPIECE_UNDERLINE = "▁"
  273. SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility
  274. MULTIPLE_CHOICE_DUMMY_INPUTS = [
  275. [[0, 1, 0, 1], [1, 0, 0, 1]]
  276. ] * 2 # Needs to have 0s and 1s only since XLM uses it for langs too.
  277. DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
  278. DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
  279. def check_min_version(min_version):
  280. if version.parse(__version__) < version.parse(min_version):
  281. if "dev" in min_version:
  282. error_message = (
  283. "This example requires a source install from HuggingFace Transformers (see "
  284. "`https://huggingface.co/docs/transformers/installation#install-from-source`),"
  285. )
  286. else:
  287. error_message = f"This example requires a minimum version of {min_version},"
  288. error_message += f" but the version found is {__version__}.\n"
  289. raise ImportError(
  290. error_message
  291. + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other "
  292. "versions of HuggingFace Transformers."
  293. )
  294. @lru_cache
  295. def get_available_devices() -> frozenset[str]:
  296. """
  297. Returns a frozenset of devices available for the current PyTorch installation.
  298. """
  299. devices = {"cpu"} # `cpu` is always supported as a device in PyTorch
  300. if is_torch_cuda_available():
  301. devices.add("cuda")
  302. if is_torch_mps_available():
  303. devices.add("mps")
  304. if is_torch_xpu_available():
  305. devices.add("xpu")
  306. if is_torch_npu_available():
  307. devices.add("npu")
  308. if is_torch_hpu_available():
  309. devices.add("hpu")
  310. if is_torch_mlu_available():
  311. devices.add("mlu")
  312. if is_torch_musa_available():
  313. devices.add("musa")
  314. return frozenset(devices)