__init__.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # When adding a new object to this init, remember to add it twice: once inside the `_import_structure` dictionary and
  15. # once inside the `if TYPE_CHECKING` branch. The `TYPE_CHECKING` should have import statements as usual, but they are
  16. # only there for type checking. The `_import_structure` is a dictionary mapping submodule names to lists of object names, and is used
  17. # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
  18. # in the namespace without actually importing anything (and especially none of the backends).
# Package release version, exposed to users as `transformers.__version__`.
__version__ = "5.5.4"
  20. import importlib
  21. import sys
  22. import types
  23. from pathlib import Path
  24. from typing import TYPE_CHECKING
  25. # Check the dependencies satisfy the minimal versions required.
  26. from . import dependency_versions_check
  27. from .utils import (
  28. OptionalDependencyNotAvailable,
  29. _LazyModule,
  30. is_essentia_available,
  31. is_g2p_en_available,
  32. is_librosa_available,
  33. is_mistral_common_available,
  34. is_mlx_available,
  35. is_numba_available,
  36. is_pretty_midi_available,
  37. )
  38. # Note: the following symbols are deliberately exported with `as`
  39. # so that mypy, pylint or other static linters can recognize them,
  40. # given that they are not exported using `__all__` in this file.
  41. from .utils import is_bitsandbytes_available as is_bitsandbytes_available
  42. from .utils import is_scipy_available as is_scipy_available
  43. from .utils import is_sentencepiece_available as is_sentencepiece_available
  44. from .utils import is_speech_available as is_speech_available
  45. from .utils import is_timm_available as is_timm_available
  46. from .utils import is_tokenizers_available as is_tokenizers_available
  47. from .utils import is_torch_available as is_torch_available
  48. from .utils import is_torchaudio_available as is_torchaudio_available
  49. from .utils import is_torchvision_available as is_torchvision_available
  50. from .utils import is_vision_available as is_vision_available
  51. from .utils import logging as logging
  52. from .utils.import_utils import define_import_structure
# Package-level logger. `logging` here is the project's own `.utils.logging` module (imported above), not the
# stdlib one, hence `get_logger` rather than `getLogger`.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Base objects, independent of any specific backend.
#
# Maps a submodule path (relative to this package; dots separate nested modules) to the names that submodule
# exposes at the top level of the package. Nothing is imported while this dict is built: it only records where
# each name lives so the actual import can be deferred (see the header note). An empty list registers the
# submodule without re-exporting any of its names.
# The optional-backend sections below add keys (or extend these lists) depending on which backends are
# installed. Anything added here must also get a matching import in the `TYPE_CHECKING` branch.
_import_structure = {
    "audio_utils": [],
    "cli": [],
    # Both spellings are exported; `PretrainedConfig` is presumably a backward-compatible alias -- confirm.
    "configuration_utils": ["PreTrainedConfig", "PretrainedConfig"],
    "convert_slow_tokenizers_checkpoints_to_fast": [],
    "data": [
        "DataProcessor",
        "InputExample",
        "InputFeatures",
        "SingleSentenceClassificationProcessor",
        "SquadExample",
        "SquadFeatures",
        "SquadV1Processor",
        "SquadV2Processor",
        "glue_compute_metrics",
        "glue_convert_examples_to_features",
        "glue_output_modes",
        "glue_processors",
        "glue_tasks_num_labels",
        "squad_convert_examples_to_features",
        "xnli_compute_metrics",
        "xnli_output_modes",
        "xnli_processors",
        "xnli_tasks_num_labels",
    ],
    "data.data_collator": [
        "DataCollator",
        "DataCollatorForLanguageModeling",
        "DataCollatorForMultipleChoice",
        "DataCollatorForPermutationLanguageModeling",
        "DataCollatorForSeq2Seq",
        "DataCollatorForSOP",
        "DataCollatorForTokenClassification",
        "DataCollatorForWholeWordMask",
        "DataCollatorWithFlattening",
        "DataCollatorWithPadding",
        "DefaultDataCollator",
        "default_data_collator",
    ],
    "data.metrics": [],
    "data.processors": [],
    "debug_utils": [],
    "dependency_versions_check": [],
    "dependency_versions_table": [],
    "dynamic_module_utils": [],
    "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
    "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
    "file_utils": [],
    # Backend-independent generation names only; the PyTorch section below extends this same list in place.
    "generation": [
        "AsyncTextIteratorStreamer",
        "CompileConfig",
        "ContinuousBatchingConfig",
        "GenerationConfig",
        "TextIteratorStreamer",
        "TextStreamer",
        "WatermarkingConfig",
    ],
    "hf_argparser": ["HfArgumentParser"],
    "hyperparameter_search": [],
    "image_processing_utils_fast": [],
    "image_transforms": [],
    "integrations": [
        "is_clearml_available",
        "is_comet_available",
        "is_dvclive_available",
        "is_neptune_available",
        "is_optuna_available",
        "is_ray_available",
        "is_ray_tune_available",
        "is_swanlab_available",
        "is_tensorboard_available",
        "is_trackio_available",
        "is_wandb_available",
    ],
    "loss": [],
    "pipelines": [
        "AnyToAnyPipeline",
        "AudioClassificationPipeline",
        "AutomaticSpeechRecognitionPipeline",
        "CsvPipelineDataFormat",
        "DepthEstimationPipeline",
        "DocumentQuestionAnsweringPipeline",
        "FeatureExtractionPipeline",
        "FillMaskPipeline",
        "ImageClassificationPipeline",
        "ImageFeatureExtractionPipeline",
        "ImageSegmentationPipeline",
        "ImageTextToTextPipeline",
        "JsonPipelineDataFormat",
        "KeypointMatchingPipeline",
        "MaskGenerationPipeline",
        "NerPipeline",
        "ObjectDetectionPipeline",
        "PipedPipelineDataFormat",
        "Pipeline",
        "PipelineDataFormat",
        "TableQuestionAnsweringPipeline",
        "TextClassificationPipeline",
        "TextGenerationPipeline",
        "TextToAudioPipeline",
        "TokenClassificationPipeline",
        "VideoClassificationPipeline",
        "ZeroShotAudioClassificationPipeline",
        "ZeroShotClassificationPipeline",
        "ZeroShotImageClassificationPipeline",
        "ZeroShotObjectDetectionPipeline",
        "pipeline",
    ],
    "processing_utils": [
        "AudioKwargs",
        "ImagesKwargs",
        "ProcessingKwargs",
        "ProcessorMixin",
        "TextKwargs",
        "VideosKwargs",
    ],
    "quantizers": [],
    "testing_utils": [],
    "tokenization_python": ["PreTrainedTokenizer", "PythonBackend"],
    "tokenization_utils": [],
    "tokenization_utils_base": [
        "AddedToken",
        "BatchEncoding",
        "CharSpan",
        "PreTrainedTokenizerBase",
        "TokenSpan",
    ],
    "tokenization_utils_fast": [],
    "tokenization_utils_sentencepiece": ["SentencePieceBackend"],
    "trainer_callback": [
        "DefaultFlowCallback",
        "EarlyStoppingCallback",
        "PrinterCallback",
        "ProgressCallback",
        "TrainerCallback",
        "TrainerControl",
        "TrainerState",
    ],
    "trainer_utils": [
        "EvalPrediction",
        "IntervalStrategy",
        "SchedulerType",
        "enable_full_determinism",
        "set_seed",
    ],
    "training_args": ["TrainingArguments"],
    "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
    # Shared constants, docstring helpers and `is_*_available` probes re-exported from `.utils`.
    "utils": [
        "CONFIG_NAME",
        "MODEL_CARD_NAME",
        "SPIECE_UNDERLINE",
        "WEIGHTS_NAME",
        "TensorType",
        "add_end_docstrings",
        "add_start_docstrings",
        "is_apex_available",
        "is_av_available",
        "is_bitsandbytes_available",
        "is_datasets_available",
        "is_faiss_available",
        "is_matplotlib_available",
        "is_mlx_available",
        "is_phonemizer_available",
        "is_psutil_available",
        "is_py3nvml_available",
        "is_pyctcdecode_available",
        "is_sacremoses_available",
        "is_scipy_available",
        "is_sentencepiece_available",
        "is_sklearn_available",
        "is_speech_available",
        "is_timm_available",
        "is_tokenizers_available",
        "is_torch_available",
        "is_torch_hpu_available",
        "is_torch_mlu_available",
        "is_torch_musa_available",
        "is_torch_neuroncore_available",
        "is_torch_npu_available",
        "is_torchvision_available",
        "is_torch_xla_available",
        "is_torch_xpu_available",
        "is_vision_available",
        "logging",
    ],
    "utils.import_utils": ["requires_backends"],
    "utils.kernel_config": ["KernelConfig"],
    "utils.quantization_config": [
        "AqlmConfig",
        "AutoRoundConfig",
        "AwqConfig",
        "BitNetQuantConfig",
        "BitsAndBytesConfig",
        "CompressedTensorsConfig",
        "EetqConfig",
        "FbgemmFp8Config",
        "FineGrainedFP8Config",
        "FourOverSixConfig",
        "FPQuantConfig",
        "GPTQConfig",
        "HiggsConfig",
        "HqqConfig",
        "MetalConfig",
        "Mxfp4Config",
        "QuantoConfig",
        "QuarkConfig",
        "SinqConfig",
        "SpQRConfig",
        "TorchAoConfig",
        "VptqConfig",
    ],
    "video_utils": [],
}
# tokenizers-backed objects
# Optional-backend pattern shared by the sections that follow: probe the backend and, when it is missing,
# register every public (non-underscore) attribute of the matching `utils.dummy_*_objects` module under that
# module's key instead of the real objects.
# NOTE(review): `dir()` also lists public names the dummy module merely imports; those get registered too --
# confirm that is intended.
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_tokenizers_objects

    _import_structure["utils.dummy_tokenizers_objects"] = [
        name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
    ]
else:
    # Fast tokenizers structure
    _import_structure["tokenization_utils_tokenizers"] = [
        "PreTrainedTokenizerFast",
        "TokenizersBackend",
    ]
# Objects needing both sentencepiece and tokenizers: `convert_slow_tokenizer` is registered only when both are
# installed; otherwise the public names of the combined dummy module are registered instead.
try:
    if not (is_sentencepiece_available() and is_tokenizers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_sentencepiece_and_tokenizers_objects

    _import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [
        name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_")
    ]
else:
    _import_structure["convert_slow_tokenizer"] = [
        "SLOW_TO_FAST_CONVERTERS",
        "convert_slow_tokenizer",
    ]
  296. try:
  297. if not (is_mistral_common_available()):
  298. raise OptionalDependencyNotAvailable()
  299. except OptionalDependencyNotAvailable:
  300. from .utils import dummy_mistral_common_objects
  301. _import_structure["utils.dummy_mistral_common_objects"] = [
  302. name for name in dir(dummy_mistral_common_objects) if not name.startswith("_")
  303. ]
  304. else:
  305. _import_structure["tokenization_mistral_common"] = ["MistralCommonBackend"]
# Vision-specific objects
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_vision_objects

    _import_structure["utils.dummy_vision_objects"] = [
        name for name in dir(dummy_vision_objects) if not name.startswith("_")
    ]
else:
    # "image_processing_backends" is only created on this path; when vision is unavailable the key is absent,
    # which is why the torchvision section that follows uses `setdefault` before extending it.
    _import_structure["image_processing_backends"] = ["PilBackend"]
    _import_structure["image_processing_base"] = ["ImageProcessingMixin"]
    _import_structure["image_processing_utils"] = ["BaseImageProcessor"]
    _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
# torchvision-backed objects
try:
    if not is_torchvision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torchvision_objects

    _import_structure["utils.dummy_torchvision_objects"] = [
        name for name in dir(dummy_torchvision_objects) if not name.startswith("_")
    ]
else:
    # The vision section creates "image_processing_backends" only when vision is available, so make sure the
    # key exists first; `+=` then extends whichever list is stored under it, in place.
    _import_structure.setdefault("image_processing_backends", [])
    _import_structure["image_processing_backends"] += ["TorchvisionBackend"]
    _import_structure["video_processing_utils"] = ["BaseVideoProcessor"]
# PyTorch-backed objects
# Everything in the `else` branch is registered only when torch is installed; otherwise the public
# (non-underscore) names of `utils.dummy_pt_objects` are registered under that module's key instead.
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_pt_objects

    _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
    _import_structure["model_debugging_utils"] = [
        "model_addition_debugger_context",
    ]
    _import_structure["activations"] = []
    # NOTE(review): "CacheLayerMixin" has no matching import among the visible `cache_utils` imports of the
    # `TYPE_CHECKING` branch, although the header asks for every name to be added in both places -- confirm.
    _import_structure["cache_utils"] = [
        "CacheLayerMixin",
        "DynamicLayer",
        "StaticLayer",
        "StaticSlidingWindowLayer",
        "QuantoQuantizedLayer",
        "HQQQuantizedLayer",
        "Cache",
        "DynamicCache",
        "EncoderDecoderCache",
        "QuantizedCache",
        "StaticCache",
    ]
    _import_structure["data.datasets"] = [
        "GlueDataset",
        "GlueDataTrainingArguments",
        "SquadDataset",
        "SquadDataTrainingArguments",
    ]
    # "generation" already holds the backend-independent names from the base structure; the torch-only
    # names are appended to that same list rather than replacing it.
    _import_structure["generation"].extend(
        [
            "AlternatingCodebooksLogitsProcessor",
            "BayesianDetectorConfig",
            "BayesianDetectorModel",
            "ClassifierFreeGuidanceLogitsProcessor",
            "ContinuousBatchingManager",
            "ContinuousMixin",
            "EncoderNoRepeatNGramLogitsProcessor",
            "EncoderRepetitionPenaltyLogitsProcessor",
            "EosTokenCriteria",
            "EpsilonLogitsWarper",
            "MinPLogitsWarper",
            "EtaLogitsWarper",
            "ExponentialDecayLengthPenalty",
            "ForcedBOSTokenLogitsProcessor",
            "ForcedEOSTokenLogitsProcessor",
            "GenerationMixin",
            "InfNanRemoveLogitsProcessor",
            "LogitNormalization",
            "LogitsProcessor",
            "LogitsProcessorList",
            "MaxLengthCriteria",
            "MaxTimeCriteria",
            "MinLengthLogitsProcessor",
            "MinNewTokensLengthLogitsProcessor",
            "NoBadWordsLogitsProcessor",
            "NoRepeatNGramLogitsProcessor",
            "PrefixConstrainedLogitsProcessor",
            "RepetitionPenaltyLogitsProcessor",
            "SequenceBiasLogitsProcessor",
            "StoppingCriteria",
            "StoppingCriteriaList",
            "StopStringCriteria",
            "SuppressTokensAtBeginLogitsProcessor",
            "SuppressTokensLogitsProcessor",
            "SynthIDTextWatermarkDetector",
            "SynthIDTextWatermarkingConfig",
            "SynthIDTextWatermarkLogitsProcessor",
            "TemperatureLogitsWarper",
            "TopHLogitsWarper",
            "TopKLogitsWarper",
            "TopPLogitsWarper",
            "TypicalLogitsWarper",
            "UnbatchedClassifierFreeGuidanceLogitsProcessor",
            "WatermarkDetector",
            "WatermarkLogitsProcessor",
            "WhisperTimeStampLogitsProcessor",
        ]
    )
    # PyTorch domain libraries integration
    _import_structure["integrations.executorch"] = [
        "TorchExportableModuleWithStaticCache",
        "convert_and_export_with_cache",
    ]
    _import_structure["core_model_loading"] = [
        "Chunk",
        "Concatenate",
        "ConversionOps",
        "MergeModulelist",
        "PermuteForRope",
        "SplitModulelist",
        "WeightConverter",
    ]
    _import_structure["modeling_flash_attention_utils"] = []
    _import_structure["modeling_layers"] = ["GradientCheckpointingLayer"]
    _import_structure["modeling_outputs"] = []
    # NOTE(review): the matching `TYPE_CHECKING` import for these two uses a plain
    # `from .backbone_utils import BackboneConfigMixin, BackboneMixin`, not the explicit `X as X` re-export form
    # the rest of that branch uses for static linters -- consider aligning it.
    _import_structure["backbone_utils"] = ["BackboneConfigMixin", "BackboneMixin"]
    _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update", "RopeParameters"]
    _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
    _import_structure["masking_utils"] = ["AttentionMaskInterface"]
    # NOTE(review): "get_reduce_on_plateau_schedule" has no matching import among the visible `optimization`
    # imports of the `TYPE_CHECKING` branch -- confirm it is imported there too.
    _import_structure["optimization"] = [
        "Adafactor",
        "get_constant_schedule",
        "get_constant_schedule_with_warmup",
        "get_cosine_schedule_with_warmup",
        "get_cosine_with_hard_restarts_schedule_with_warmup",
        "get_cosine_with_min_lr_schedule_with_warmup",
        "get_cosine_with_min_lr_schedule_with_warmup_lr_rate",
        "get_greedy_schedule",
        "get_inverse_sqrt_schedule",
        "get_linear_schedule_with_warmup",
        "get_polynomial_decay_schedule_with_warmup",
        "get_reduce_on_plateau_schedule",
        "get_scheduler",
        "get_wsd_schedule",
        "GreedyLR",
    ]
    _import_structure["pytorch_utils"] = ["Conv1D", "apply_chunking_to_forward"]
    _import_structure["time_series_utils"] = []
    _import_structure["trainer"] = ["Trainer"]
    _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
    _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]
  456. # Direct imports for type-checking
  457. if TYPE_CHECKING:
  458. # All modeling imports
  459. # Models
  460. from .backbone_utils import BackboneConfigMixin, BackboneMixin
  461. from .cache_utils import Cache as Cache
  462. from .cache_utils import DynamicCache as DynamicCache
  463. from .cache_utils import DynamicLayer as DynamicLayer
  464. from .cache_utils import EncoderDecoderCache as EncoderDecoderCache
  465. from .cache_utils import HQQQuantizedLayer as HQQQuantizedLayer
  466. from .cache_utils import QuantizedCache as QuantizedCache
  467. from .cache_utils import QuantoQuantizedLayer as QuantoQuantizedLayer
  468. from .cache_utils import StaticCache as StaticCache
  469. from .cache_utils import StaticLayer as StaticLayer
  470. from .cache_utils import StaticSlidingWindowLayer as StaticSlidingWindowLayer
  471. from .configuration_utils import PreTrainedConfig as PreTrainedConfig
  472. from .configuration_utils import PretrainedConfig as PretrainedConfig
  473. from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS
  474. from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer
  475. from .core_model_loading import Chunk as Chunk
  476. from .core_model_loading import Concatenate as Concatenate
  477. from .core_model_loading import ConversionOps as ConversionOps
  478. from .core_model_loading import MergeModulelist as MergeModulelist
  479. from .core_model_loading import PermuteForRope as PermuteForRope
  480. from .core_model_loading import SplitModulelist as SplitModulelist
  481. from .core_model_loading import WeightConverter as WeightConverter
  482. # Data
  483. from .data import DataProcessor as DataProcessor
  484. from .data import InputExample as InputExample
  485. from .data import InputFeatures as InputFeatures
  486. from .data import SingleSentenceClassificationProcessor as SingleSentenceClassificationProcessor
  487. from .data import SquadExample as SquadExample
  488. from .data import SquadFeatures as SquadFeatures
  489. from .data import SquadV1Processor as SquadV1Processor
  490. from .data import SquadV2Processor as SquadV2Processor
  491. from .data import glue_compute_metrics as glue_compute_metrics
  492. from .data import glue_convert_examples_to_features as glue_convert_examples_to_features
  493. from .data import glue_output_modes as glue_output_modes
  494. from .data import glue_processors as glue_processors
  495. from .data import glue_tasks_num_labels as glue_tasks_num_labels
  496. from .data import squad_convert_examples_to_features as squad_convert_examples_to_features
  497. from .data import xnli_compute_metrics as xnli_compute_metrics
  498. from .data import xnli_output_modes as xnli_output_modes
  499. from .data import xnli_processors as xnli_processors
  500. from .data import xnli_tasks_num_labels as xnli_tasks_num_labels
  501. from .data.data_collator import DataCollator as DataCollator
  502. from .data.data_collator import DataCollatorForLanguageModeling as DataCollatorForLanguageModeling
  503. from .data.data_collator import DataCollatorForMultipleChoice as DataCollatorForMultipleChoice
  504. from .data.data_collator import (
  505. DataCollatorForPermutationLanguageModeling as DataCollatorForPermutationLanguageModeling,
  506. )
  507. from .data.data_collator import DataCollatorForSeq2Seq as DataCollatorForSeq2Seq
  508. from .data.data_collator import DataCollatorForSOP as DataCollatorForSOP
  509. from .data.data_collator import DataCollatorForTokenClassification as DataCollatorForTokenClassification
  510. from .data.data_collator import DataCollatorForWholeWordMask as DataCollatorForWholeWordMask
  511. from .data.data_collator import DataCollatorWithFlattening as DataCollatorWithFlattening
  512. from .data.data_collator import DataCollatorWithPadding as DataCollatorWithPadding
  513. from .data.data_collator import DefaultDataCollator as DefaultDataCollator
  514. from .data.data_collator import default_data_collator as default_data_collator
  515. from .data.datasets import GlueDataset as GlueDataset
  516. from .data.datasets import GlueDataTrainingArguments as GlueDataTrainingArguments
  517. from .data.datasets import SquadDataset as SquadDataset
  518. from .data.datasets import SquadDataTrainingArguments as SquadDataTrainingArguments
  519. from .feature_extraction_sequence_utils import SequenceFeatureExtractor as SequenceFeatureExtractor
  520. # Feature Extractor
  521. from .feature_extraction_utils import BatchFeature as BatchFeature
  522. from .feature_extraction_utils import FeatureExtractionMixin as FeatureExtractionMixin
  523. # Generation
  524. from .generation import AlternatingCodebooksLogitsProcessor as AlternatingCodebooksLogitsProcessor
  525. from .generation import AsyncTextIteratorStreamer as AsyncTextIteratorStreamer
  526. from .generation import BayesianDetectorConfig as BayesianDetectorConfig
  527. from .generation import BayesianDetectorModel as BayesianDetectorModel
  528. from .generation import ClassifierFreeGuidanceLogitsProcessor as ClassifierFreeGuidanceLogitsProcessor
  529. from .generation import CompileConfig as CompileConfig
  530. from .generation import ContinuousBatchingConfig as ContinuousBatchingConfig
  531. from .generation import ContinuousBatchingManager as ContinuousBatchingManager
  532. from .generation import ContinuousMixin as ContinuousMixin
  533. from .generation import EncoderNoRepeatNGramLogitsProcessor as EncoderNoRepeatNGramLogitsProcessor
  534. from .generation import EncoderRepetitionPenaltyLogitsProcessor as EncoderRepetitionPenaltyLogitsProcessor
  535. from .generation import EosTokenCriteria as EosTokenCriteria
  536. from .generation import EpsilonLogitsWarper as EpsilonLogitsWarper
  537. from .generation import EtaLogitsWarper as EtaLogitsWarper
  538. from .generation import ExponentialDecayLengthPenalty as ExponentialDecayLengthPenalty
  539. from .generation import ForcedBOSTokenLogitsProcessor as ForcedBOSTokenLogitsProcessor
  540. from .generation import ForcedEOSTokenLogitsProcessor as ForcedEOSTokenLogitsProcessor
  541. from .generation import GenerationConfig as GenerationConfig
  542. from .generation import GenerationMixin as GenerationMixin
  543. from .generation import InfNanRemoveLogitsProcessor as InfNanRemoveLogitsProcessor
  544. from .generation import LogitNormalization as LogitNormalization
  545. from .generation import LogitsProcessor as LogitsProcessor
  546. from .generation import LogitsProcessorList as LogitsProcessorList
  547. from .generation import MaxLengthCriteria as MaxLengthCriteria
  548. from .generation import MaxTimeCriteria as MaxTimeCriteria
  549. from .generation import MinLengthLogitsProcessor as MinLengthLogitsProcessor
  550. from .generation import MinNewTokensLengthLogitsProcessor as MinNewTokensLengthLogitsProcessor
  551. from .generation import MinPLogitsWarper as MinPLogitsWarper
  552. from .generation import NoBadWordsLogitsProcessor as NoBadWordsLogitsProcessor
  553. from .generation import NoRepeatNGramLogitsProcessor as NoRepeatNGramLogitsProcessor
  554. from .generation import PrefixConstrainedLogitsProcessor as PrefixConstrainedLogitsProcessor
  555. from .generation import RepetitionPenaltyLogitsProcessor as RepetitionPenaltyLogitsProcessor
  556. from .generation import SequenceBiasLogitsProcessor as SequenceBiasLogitsProcessor
  557. from .generation import StoppingCriteria as StoppingCriteria
  558. from .generation import StoppingCriteriaList as StoppingCriteriaList
  559. from .generation import StopStringCriteria as StopStringCriteria
  560. from .generation import SuppressTokensAtBeginLogitsProcessor as SuppressTokensAtBeginLogitsProcessor
  561. from .generation import SuppressTokensLogitsProcessor as SuppressTokensLogitsProcessor
  562. from .generation import SynthIDTextWatermarkDetector as SynthIDTextWatermarkDetector
  563. from .generation import SynthIDTextWatermarkingConfig as SynthIDTextWatermarkingConfig
  564. from .generation import SynthIDTextWatermarkLogitsProcessor as SynthIDTextWatermarkLogitsProcessor
  565. from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper
  566. from .generation import TextIteratorStreamer as TextIteratorStreamer
  567. from .generation import TextStreamer as TextStreamer
  568. from .generation import TopHLogitsWarper as TopHLogitsWarper
  569. from .generation import TopKLogitsWarper as TopKLogitsWarper
  570. from .generation import TopPLogitsWarper as TopPLogitsWarper
  571. from .generation import TypicalLogitsWarper as TypicalLogitsWarper
  572. from .generation import (
  573. UnbatchedClassifierFreeGuidanceLogitsProcessor as UnbatchedClassifierFreeGuidanceLogitsProcessor,
  574. )
  575. from .generation import WatermarkDetector as WatermarkDetector
  576. from .generation import WatermarkingConfig as WatermarkingConfig
  577. from .generation import WatermarkLogitsProcessor as WatermarkLogitsProcessor
  578. from .generation import WhisperTimeStampLogitsProcessor as WhisperTimeStampLogitsProcessor
  579. from .hf_argparser import HfArgumentParser as HfArgumentParser
  580. from .image_processing_backends import PilBackend as PilBackend
  581. from .image_processing_backends import TorchvisionBackend as TorchvisionBackend
  582. from .image_processing_base import ImageProcessingMixin as ImageProcessingMixin
  583. from .image_processing_utils import BaseImageProcessor as BaseImageProcessor
  584. from .image_utils import ImageFeatureExtractionMixin as ImageFeatureExtractionMixin
  585. # Integrations
  586. from .integrations import is_clearml_available as is_clearml_available
  587. from .integrations import is_comet_available as is_comet_available
  588. from .integrations import is_dvclive_available as is_dvclive_available
  589. from .integrations import is_neptune_available as is_neptune_available
  590. from .integrations import is_optuna_available as is_optuna_available
  591. from .integrations import is_ray_available as is_ray_available
  592. from .integrations import is_ray_tune_available as is_ray_tune_available
  593. from .integrations import is_swanlab_available as is_swanlab_available
  594. from .integrations import is_tensorboard_available as is_tensorboard_available
  595. from .integrations import is_trackio_available as is_trackio_available
  596. from .integrations import is_wandb_available as is_wandb_available
  597. from .integrations.executorch import TorchExportableModuleWithStaticCache as TorchExportableModuleWithStaticCache
  598. from .integrations.executorch import convert_and_export_with_cache as convert_and_export_with_cache
  599. from .masking_utils import AttentionMaskInterface as AttentionMaskInterface
  600. from .model_debugging_utils import model_addition_debugger_context as model_addition_debugger_context
  601. from .modeling_layers import GradientCheckpointingLayer as GradientCheckpointingLayer
  602. from .modeling_rope_utils import ROPE_INIT_FUNCTIONS as ROPE_INIT_FUNCTIONS
  603. from .modeling_rope_utils import RopeParameters as RopeParameters
  604. from .modeling_rope_utils import dynamic_rope_update as dynamic_rope_update
  605. from .modeling_utils import AttentionInterface as AttentionInterface
  606. from .modeling_utils import PreTrainedModel as PreTrainedModel
  607. from .models import *
  608. from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor
  609. # Optimization
  610. from .optimization import Adafactor as Adafactor
  611. from .optimization import GreedyLR as GreedyLR
  612. from .optimization import get_constant_schedule as get_constant_schedule
  613. from .optimization import get_constant_schedule_with_warmup as get_constant_schedule_with_warmup
  614. from .optimization import get_cosine_schedule_with_warmup as get_cosine_schedule_with_warmup
  615. from .optimization import (
  616. get_cosine_with_hard_restarts_schedule_with_warmup as get_cosine_with_hard_restarts_schedule_with_warmup,
  617. )
  618. from .optimization import (
  619. get_cosine_with_min_lr_schedule_with_warmup as get_cosine_with_min_lr_schedule_with_warmup,
  620. )
  621. from .optimization import (
  622. get_cosine_with_min_lr_schedule_with_warmup_lr_rate as get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
  623. )
  624. from .optimization import get_greedy_schedule as get_greedy_schedule
  625. from .optimization import get_inverse_sqrt_schedule as get_inverse_sqrt_schedule
  626. from .optimization import get_linear_schedule_with_warmup as get_linear_schedule_with_warmup
  627. from .optimization import get_polynomial_decay_schedule_with_warmup as get_polynomial_decay_schedule_with_warmup
  628. from .optimization import get_scheduler as get_scheduler
  629. from .optimization import get_wsd_schedule as get_wsd_schedule
  630. # Pipelines
  631. from .pipelines import AnyToAnyPipeline as AnyToAnyPipeline
  632. from .pipelines import AudioClassificationPipeline as AudioClassificationPipeline
  633. from .pipelines import AutomaticSpeechRecognitionPipeline as AutomaticSpeechRecognitionPipeline
  634. from .pipelines import CsvPipelineDataFormat as CsvPipelineDataFormat
  635. from .pipelines import DepthEstimationPipeline as DepthEstimationPipeline
  636. from .pipelines import DocumentQuestionAnsweringPipeline as DocumentQuestionAnsweringPipeline
  637. from .pipelines import FeatureExtractionPipeline as FeatureExtractionPipeline
  638. from .pipelines import FillMaskPipeline as FillMaskPipeline
  639. from .pipelines import ImageClassificationPipeline as ImageClassificationPipeline
  640. from .pipelines import ImageFeatureExtractionPipeline as ImageFeatureExtractionPipeline
  641. from .pipelines import ImageSegmentationPipeline as ImageSegmentationPipeline
  642. from .pipelines import ImageTextToTextPipeline as ImageTextToTextPipeline
  643. from .pipelines import JsonPipelineDataFormat as JsonPipelineDataFormat
  644. from .pipelines import KeypointMatchingPipeline as KeypointMatchingPipeline
  645. from .pipelines import MaskGenerationPipeline as MaskGenerationPipeline
  646. from .pipelines import NerPipeline as NerPipeline
  647. from .pipelines import ObjectDetectionPipeline as ObjectDetectionPipeline
  648. from .pipelines import PipedPipelineDataFormat as PipedPipelineDataFormat
  649. from .pipelines import Pipeline as Pipeline
  650. from .pipelines import PipelineDataFormat as PipelineDataFormat
  651. from .pipelines import TableQuestionAnsweringPipeline as TableQuestionAnsweringPipeline
  652. from .pipelines import TextClassificationPipeline as TextClassificationPipeline
  653. from .pipelines import TextGenerationPipeline as TextGenerationPipeline
  654. from .pipelines import TextToAudioPipeline as TextToAudioPipeline
  655. from .pipelines import TokenClassificationPipeline as TokenClassificationPipeline
  656. from .pipelines import VideoClassificationPipeline as VideoClassificationPipeline
  657. from .pipelines import ZeroShotAudioClassificationPipeline as ZeroShotAudioClassificationPipeline
  658. from .pipelines import ZeroShotClassificationPipeline as ZeroShotClassificationPipeline
  659. from .pipelines import ZeroShotImageClassificationPipeline as ZeroShotImageClassificationPipeline
  660. from .pipelines import ZeroShotObjectDetectionPipeline as ZeroShotObjectDetectionPipeline
  661. from .pipelines import pipeline as pipeline
  662. from .processing_utils import AudioKwargs as AudioKwargs
  663. from .processing_utils import ImagesKwargs as ImagesKwargs
  664. from .processing_utils import ProcessingKwargs as ProcessingKwargs
  665. from .processing_utils import ProcessorMixin as ProcessorMixin
  666. from .processing_utils import TextKwargs as TextKwargs
  667. from .processing_utils import VideosKwargs as VideosKwargs
  668. from .pytorch_utils import Conv1D as Conv1D
  669. from .pytorch_utils import apply_chunking_to_forward as apply_chunking_to_forward
  670. # Tokenization
  671. from .tokenization_python import PreTrainedTokenizer as PreTrainedTokenizer
  672. from .tokenization_python import PythonBackend as PythonBackend
  673. from .tokenization_utils_base import AddedToken as AddedToken
  674. from .tokenization_utils_base import BatchEncoding as BatchEncoding
  675. from .tokenization_utils_base import CharSpan as CharSpan
  676. from .tokenization_utils_base import PreTrainedTokenizerBase as PreTrainedTokenizerBase
  677. from .tokenization_utils_base import TokenSpan as TokenSpan
  678. # Tokenization
  679. from .tokenization_utils_sentencepiece import SentencePieceBackend as SentencePieceBackend
  680. from .tokenization_utils_tokenizers import PreTrainedTokenizerFast as PreTrainedTokenizerFast
  681. from .tokenization_utils_tokenizers import (
  682. TokenizersBackend as TokenizersBackend,
  683. )
  684. # Trainer
  685. from .trainer import Trainer as Trainer
  686. from .trainer_callback import DefaultFlowCallback as DefaultFlowCallback
  687. from .trainer_callback import EarlyStoppingCallback as EarlyStoppingCallback
  688. from .trainer_callback import PrinterCallback as PrinterCallback
  689. from .trainer_callback import ProgressCallback as ProgressCallback
  690. from .trainer_callback import TrainerCallback as TrainerCallback
  691. from .trainer_callback import TrainerControl as TrainerControl
  692. from .trainer_callback import TrainerState as TrainerState
  693. from .trainer_pt_utils import torch_distributed_zero_first as torch_distributed_zero_first
  694. from .trainer_seq2seq import Seq2SeqTrainer as Seq2SeqTrainer
  695. from .trainer_utils import EvalPrediction as EvalPrediction
  696. from .trainer_utils import IntervalStrategy as IntervalStrategy
  697. from .trainer_utils import SchedulerType as SchedulerType
  698. from .trainer_utils import enable_full_determinism as enable_full_determinism
  699. from .trainer_utils import set_seed as set_seed
  700. from .training_args import TrainingArguments as TrainingArguments
  701. from .training_args_seq2seq import Seq2SeqTrainingArguments as Seq2SeqTrainingArguments
  702. # Files and general utilities
  703. from .utils import CONFIG_NAME as CONFIG_NAME
  704. from .utils import MODEL_CARD_NAME as MODEL_CARD_NAME
  705. from .utils import SPIECE_UNDERLINE as SPIECE_UNDERLINE
  706. from .utils import WEIGHTS_NAME as WEIGHTS_NAME
  707. from .utils import TensorType as TensorType
  708. from .utils import add_end_docstrings as add_end_docstrings
  709. from .utils import add_start_docstrings as add_start_docstrings
  710. from .utils import is_apex_available as is_apex_available
  711. from .utils import is_av_available as is_av_available
  712. from .utils import is_datasets_available as is_datasets_available
  713. from .utils import is_faiss_available as is_faiss_available
  714. from .utils import is_matplotlib_available as is_matplotlib_available
  715. from .utils import is_phonemizer_available as is_phonemizer_available
  716. from .utils import is_psutil_available as is_psutil_available
  717. from .utils import is_py3nvml_available as is_py3nvml_available
  718. from .utils import is_pyctcdecode_available as is_pyctcdecode_available
  719. from .utils import is_sacremoses_available as is_sacremoses_available
  720. from .utils import is_sklearn_available as is_sklearn_available
  721. from .utils import is_torch_hpu_available as is_torch_hpu_available
  722. from .utils import is_torch_mlu_available as is_torch_mlu_available
  723. from .utils import is_torch_musa_available as is_torch_musa_available
  724. from .utils import is_torch_neuroncore_available as is_torch_neuroncore_available
  725. from .utils import is_torch_npu_available as is_torch_npu_available
  726. from .utils import is_torch_xla_available as is_torch_xla_available
  727. from .utils import is_torch_xpu_available as is_torch_xpu_available
  728. from .utils.import_utils import requires_backends
  729. from .utils.kernel_config import KernelConfig as KernelConfig
  730. # Quantization config
  731. from .utils.quantization_config import AqlmConfig as AqlmConfig
  732. from .utils.quantization_config import AutoRoundConfig as AutoRoundConfig
  733. from .utils.quantization_config import AwqConfig as AwqConfig
  734. from .utils.quantization_config import BitNetQuantConfig as BitNetQuantConfig
  735. from .utils.quantization_config import BitsAndBytesConfig as BitsAndBytesConfig
  736. from .utils.quantization_config import CompressedTensorsConfig as CompressedTensorsConfig
  737. from .utils.quantization_config import EetqConfig as EetqConfig
  738. from .utils.quantization_config import FbgemmFp8Config as FbgemmFp8Config
  739. from .utils.quantization_config import FineGrainedFP8Config as FineGrainedFP8Config
  740. from .utils.quantization_config import FourOverSixConfig as FourOverSixConfig
  741. from .utils.quantization_config import FPQuantConfig as FPQuantConfig
  742. from .utils.quantization_config import GPTQConfig as GPTQConfig
  743. from .utils.quantization_config import HiggsConfig as HiggsConfig
  744. from .utils.quantization_config import HqqConfig as HqqConfig
  745. from .utils.quantization_config import MetalConfig as MetalConfig
  746. from .utils.quantization_config import QuantoConfig as QuantoConfig
  747. from .utils.quantization_config import QuarkConfig as QuarkConfig
  748. from .utils.quantization_config import SinqConfig as SinqConfig
  749. from .utils.quantization_config import SpQRConfig as SpQRConfig
  750. from .utils.quantization_config import TorchAoConfig as TorchAoConfig
  751. from .utils.quantization_config import VptqConfig as VptqConfig
  752. from .video_processing_utils import BaseVideoProcessor as BaseVideoProcessor
  753. else:
# Runtime (non-type-checking) path: rather than executing the eager imports of the type-checking branch, this
# package is exposed through `_LazyModule`.
# NOTE(review): `_LazyModule` presumably imports each submodule only when one of its names is first accessed —
# confirm against its implementation.

# Normalize every value of the hand-written `_import_structure` (presumably built earlier in this file) to a set
# before it is merged into the discovered structure below.
_import_structure = {k: set(v) for k, v in _import_structure.items()}

# Discover the import structure of the model subpackages under `models/`, passing "models" as the key prefix.
import_structure = define_import_structure(Path(__file__).parent / "models", prefix="models")
# Add the hand-written top-level entries to the bucket keyed by the empty frozenset (`frozenset({})` is the
# empty frozenset). NOTE(review): that key presumably means "no extra backend required" — verify against
# `define_import_structure`.
import_structure[frozenset({})].update(_import_structure)

# Replace this package's entry in `sys.modules` with the lazy proxy. Code below that reads
# `sys.modules[__name__]` (e.g. the legacy alias registration) sees the proxy, not the original module object.
# `__version__` is handed over via `extra_objects`.
sys.modules[__name__] = _LazyModule(
    __name__,
    globals()["__file__"],
    import_structure,
    module_spec=__spec__,
    extra_objects={"__version__": __version__},
)
  764. def _create_module_alias(alias: str, target: str) -> None:
  765. """
  766. Lazily redirect legacy module paths to their replacements without importing heavy deps.
  767. """
  768. module = types.ModuleType(alias)
  769. module.__doc__ = f"Alias module for backward compatibility with `{target}`."
  770. # Set __file__ explicitly so that inspect.py's hasattr(module, '__file__') check
  771. # never falls through to __getattr__ and triggers a premature (possibly circular) import.
  772. module.__file__ = None
  773. def _get_target():
  774. return importlib.import_module(target, __name__)
  775. module.__getattr__ = lambda name: getattr(_get_target(), name)
  776. module.__dir__ = lambda: dir(_get_target())
  777. sys.modules[alias] = module
  778. setattr(sys.modules[__name__], alias.rsplit(".", 1)[-1], module)
  779. _create_module_alias(f"{__name__}.tokenization_utils_fast", ".tokenization_utils_tokenizers")
  780. _create_module_alias(f"{__name__}.tokenization_utils", ".tokenization_utils_sentencepiece")
  781. _create_module_alias(f"{__name__}.image_processing_utils_fast", ".image_processing_backends")
  782. for _proc_file in sorted((Path(__file__).parent / "models").rglob("image_processing_*.py")):
  783. _model = _proc_file.parent.name
  784. _module = _proc_file.stem
  785. _target = f".models.{_model}.{_module}"
  786. _create_module_alias(f"{__name__}.models.{_model}.{_module}_fast", _target)
  787. # Also map XImageProcessorFast -> XImageProcessor for backward compat with old class names.
  788. def getattr_factory(target):
  789. def _getattr(name):
  790. new_name = name.removesuffix("Fast")
  791. logger.warning(
  792. "Accessing `%s` from `%s`. Returning `%s` instead. Behavior may be "
  793. "different and this alias will be removed in future versions.",
  794. name,
  795. target,
  796. new_name,
  797. )
  798. return getattr(importlib.import_module(target, __name__), new_name)
  799. return _getattr
  800. sys.modules[f"{__name__}.models.{_model}.{_module}_fast"].__getattr__ = getattr_factory(_target)
  801. if not is_torch_available():
  802. logger.warning_advice(
  803. "PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used."
  804. )