__init__.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_torch_greater_or_equal
  16. _import_structure = {
  17. "aqlm": ["replace_with_aqlm_linear"],
  18. "awq": [
  19. "post_init_awq_exllama_modules",
  20. "replace_quantization_scales",
  21. "replace_with_awq_linear",
  22. ],
  23. "bitnet": [
  24. "BitLinear",
  25. "pack_weights",
  26. "replace_with_bitnet_linear",
  27. "unpack_weights",
  28. ],
  29. "bitsandbytes": [
  30. "Bnb4bitQuantize",
  31. "dequantize_and_replace",
  32. "replace_with_bnb_linear",
  33. "validate_bnb_backend_availability",
  34. ],
  35. "deepspeed": [
  36. "HfDeepSpeedConfig",
  37. "HfTrainerDeepSpeedConfig",
  38. "deepspeed_config",
  39. "deepspeed_init",
  40. "deepspeed_load_checkpoint",
  41. "deepspeed_optim_sched",
  42. "is_deepspeed_available",
  43. "is_deepspeed_zero3_enabled",
  44. "set_hf_deepspeed_config",
  45. "unset_hf_deepspeed_config",
  46. ],
  47. "eetq": ["replace_with_eetq_linear"],
  48. "fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
  49. "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
  50. "fsdp": ["is_fsdp_enabled", "is_fsdp_managed_module"],
  51. "ggml": [
  52. "GGUF_CONFIG_DEFAULTS_MAPPING",
  53. "GGUF_CONFIG_MAPPING",
  54. "GGUF_TOKENIZER_MAPPING",
  55. "_gguf_parse_value",
  56. "load_dequant_gguf_tensor",
  57. "load_gguf",
  58. ],
  59. "higgs": [
  60. "HiggsLinear",
  61. "dequantize_higgs",
  62. "quantize_with_higgs",
  63. "replace_with_higgs_linear",
  64. ],
  65. "hqq": ["prepare_for_hqq_linear"],
  66. "hub_kernels": [
  67. "LayerRepository",
  68. "lazy_load_kernel",
  69. "register_kernel_mapping",
  70. "replace_kernel_forward_from_hub",
  71. "use_kernel_forward_from_hub",
  72. "use_kernel_func_from_hub",
  73. "use_kernelized_func",
  74. ],
  75. "integration_utils": [
  76. "INTEGRATION_TO_CALLBACK",
  77. "AzureMLCallback",
  78. "ClearMLCallback",
  79. "CodeCarbonCallback",
  80. "CometCallback",
  81. "DagsHubCallback",
  82. "DVCLiveCallback",
  83. "FlyteCallback",
  84. "KubeflowCallback",
  85. "MLflowCallback",
  86. "NeptuneCallback",
  87. "NeptuneMissingConfiguration",
  88. "SwanLabCallback",
  89. "TensorBoardCallback",
  90. "TrackioCallback",
  91. "WandbCallback",
  92. "get_available_reporting_integrations",
  93. "get_reporting_integration_callbacks",
  94. "hp_params",
  95. "is_azureml_available",
  96. "is_clearml_available",
  97. "is_codecarbon_available",
  98. "is_comet_available",
  99. "is_dagshub_available",
  100. "is_dvclive_available",
  101. "is_flyte_deck_standard_available",
  102. "is_flytekit_available",
  103. "is_kubeflow_available",
  104. "is_mlflow_available",
  105. "is_neptune_available",
  106. "is_optuna_available",
  107. "is_ray_available",
  108. "is_ray_tune_available",
  109. "is_swanlab_available",
  110. "is_tensorboard_available",
  111. "is_trackio_available",
  112. "is_wandb_available",
  113. "rewrite_logs",
  114. "run_hp_search_optuna",
  115. "run_hp_search_ray",
  116. "run_hp_search_wandb",
  117. ],
  118. "liger": ["apply_liger_kernel"],
  119. "metal_quantization": [
  120. "MetalLinear",
  121. "replace_with_metal_linear",
  122. ],
  123. "moe": [
  124. "batched_mm_experts_forward",
  125. "grouped_mm_experts_forward",
  126. "use_experts_implementation",
  127. ],
  128. "mxfp4": [
  129. "Mxfp4GptOssExperts",
  130. "convert_moe_packed_tensors",
  131. "dequantize",
  132. "load_and_swizzle_mxfp4",
  133. "quantize_to_mxfp4",
  134. "replace_with_mxfp4_linear",
  135. "swizzle_mxfp4",
  136. ],
  137. "neftune": [
  138. "activate_neftune",
  139. "deactivate_neftune",
  140. "neftune_post_forward_hook",
  141. ],
  142. "peft": ["PeftAdapterMixin"],
  143. "quanto": ["replace_with_quanto_layers"],
  144. "sinq": ["SinqDeserialize", "SinqQuantize"],
  145. "spqr": ["replace_with_spqr_linear"],
  146. "vptq": ["replace_with_vptq_linear"],
  147. }
  148. try:
  149. if not is_torch_available():
  150. raise OptionalDependencyNotAvailable()
  151. except OptionalDependencyNotAvailable:
  152. pass
  153. else:
  154. _import_structure["executorch"] = [
  155. "TorchExportableModuleWithStaticCache",
  156. "convert_and_export_with_cache",
  157. ]
  158. _import_structure["tensor_parallel"] = [
  159. "shard_and_distribute_module",
  160. "ALL_PARALLEL_STYLES",
  161. "translate_to_torch_parallel_style",
  162. ]
  163. try:
  164. if not is_torch_greater_or_equal("2.5"):
  165. raise OptionalDependencyNotAvailable()
  166. except OptionalDependencyNotAvailable:
  167. pass
  168. else:
  169. _import_structure["flex_attention"] = [
  170. "make_flex_block_causal_mask",
  171. ]
  172. if TYPE_CHECKING:
  173. from .aqlm import replace_with_aqlm_linear
  174. from .awq import (
  175. post_init_awq_exllama_modules,
  176. replace_quantization_scales,
  177. replace_with_awq_linear,
  178. )
  179. from .bitnet import (
  180. BitLinear,
  181. pack_weights,
  182. replace_with_bitnet_linear,
  183. unpack_weights,
  184. )
  185. from .bitsandbytes import (
  186. Bnb4bitQuantize,
  187. dequantize_and_replace,
  188. replace_with_bnb_linear,
  189. validate_bnb_backend_availability,
  190. )
  191. from .deepspeed import (
  192. HfDeepSpeedConfig,
  193. HfTrainerDeepSpeedConfig,
  194. deepspeed_config,
  195. deepspeed_init,
  196. deepspeed_load_checkpoint,
  197. deepspeed_optim_sched,
  198. is_deepspeed_available,
  199. is_deepspeed_zero3_enabled,
  200. set_hf_deepspeed_config,
  201. unset_hf_deepspeed_config,
  202. )
  203. from .eetq import replace_with_eetq_linear
  204. from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
  205. from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
  206. from .fsdp import is_fsdp_enabled, is_fsdp_managed_module
  207. from .ggml import (
  208. GGUF_CONFIG_DEFAULTS_MAPPING,
  209. GGUF_CONFIG_MAPPING,
  210. GGUF_TOKENIZER_MAPPING,
  211. _gguf_parse_value,
  212. load_dequant_gguf_tensor,
  213. load_gguf,
  214. )
  215. from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
  216. from .hqq import prepare_for_hqq_linear
  217. from .hub_kernels import (
  218. LayerRepository,
  219. lazy_load_kernel,
  220. register_kernel_mapping,
  221. replace_kernel_forward_from_hub,
  222. use_kernel_forward_from_hub,
  223. use_kernel_func_from_hub,
  224. use_kernelized_func,
  225. )
  226. from .integration_utils import (
  227. INTEGRATION_TO_CALLBACK,
  228. AzureMLCallback,
  229. ClearMLCallback,
  230. CodeCarbonCallback,
  231. CometCallback,
  232. DagsHubCallback,
  233. DVCLiveCallback,
  234. FlyteCallback,
  235. KubeflowCallback,
  236. MLflowCallback,
  237. NeptuneCallback,
  238. NeptuneMissingConfiguration,
  239. SwanLabCallback,
  240. TensorBoardCallback,
  241. TrackioCallback,
  242. WandbCallback,
  243. get_available_reporting_integrations,
  244. get_reporting_integration_callbacks,
  245. hp_params,
  246. is_azureml_available,
  247. is_clearml_available,
  248. is_codecarbon_available,
  249. is_comet_available,
  250. is_dagshub_available,
  251. is_dvclive_available,
  252. is_flyte_deck_standard_available,
  253. is_flytekit_available,
  254. is_kubeflow_available,
  255. is_mlflow_available,
  256. is_neptune_available,
  257. is_optuna_available,
  258. is_ray_available,
  259. is_ray_tune_available,
  260. is_swanlab_available,
  261. is_tensorboard_available,
  262. is_trackio_available,
  263. is_wandb_available,
  264. rewrite_logs,
  265. run_hp_search_optuna,
  266. run_hp_search_ray,
  267. run_hp_search_wandb,
  268. )
  269. from .liger import apply_liger_kernel
  270. from .metal_quantization import (
  271. MetalLinear,
  272. replace_with_metal_linear,
  273. )
  274. from .moe import (
  275. batched_mm_experts_forward,
  276. grouped_mm_experts_forward,
  277. use_experts_implementation,
  278. )
  279. from .mxfp4 import (
  280. Mxfp4GptOssExperts,
  281. dequantize,
  282. load_and_swizzle_mxfp4,
  283. quantize_to_mxfp4,
  284. replace_with_mxfp4_linear,
  285. swizzle_mxfp4,
  286. )
  287. from .neftune import activate_neftune, deactivate_neftune, neftune_post_forward_hook
  288. from .peft import PeftAdapterMixin
  289. from .quanto import replace_with_quanto_layers
  290. from .sinq import SinqDeserialize, SinqQuantize
  291. from .spqr import replace_with_spqr_linear
  292. from .vptq import replace_with_vptq_linear
  293. try:
  294. if not is_torch_available():
  295. raise OptionalDependencyNotAvailable()
  296. except OptionalDependencyNotAvailable:
  297. pass
  298. else:
  299. from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache
  300. from .tensor_parallel import (
  301. ALL_PARALLEL_STYLES,
  302. shard_and_distribute_module,
  303. translate_to_torch_parallel_style,
  304. )
  305. try:
  306. if not is_torch_greater_or_equal("2.5"):
  307. raise OptionalDependencyNotAvailable()
  308. except OptionalDependencyNotAvailable:
  309. pass
  310. else:
  311. from .flex_attention import make_flex_block_causal_mask
  312. else:
  313. import sys
  314. sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)