conversion_mapping.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639
  1. # Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. from copy import deepcopy
  16. from typing import TYPE_CHECKING
  17. from .core_model_loading import (
  18. Chunk,
  19. Concatenate,
  20. ErnieFuseAndSplitTextVisionExperts,
  21. MergeModulelist,
  22. Transpose,
  23. WeightConverter,
  24. WeightRenaming,
  25. )
  26. if TYPE_CHECKING:
  27. from .modeling_utils import PreTrainedModel
  28. from .quantizers import HfQuantizer
  29. _MODEL_TO_CONVERSION_PATTERN = {
  30. # Mixtral-style MoE
  31. "minimax": "mixtral",
  32. "minimax_m2": "mixtral",
  33. # Qwen2-style MoE
  34. "afmoe": "qwen2_moe",
  35. "deepseek_v2": "qwen2_moe",
  36. "deepseek_v3": "qwen2_moe",
  37. "dots1": "qwen2_moe",
  38. "ernie4_5_moe": "qwen2_moe",
  39. "glm4_moe": "qwen2_moe",
  40. "glm4_moe_lite": "qwen2_moe",
  41. "glm_moe_dsa": "qwen2_moe",
  42. "glm4v_moe": "qwen2_moe",
  43. "longcat_flash": "qwen2_moe",
  44. "solar_open": "qwen2_moe",
  45. "qwen3_moe": "qwen2_moe",
  46. "qwen3_omni_moe": "qwen2_moe",
  47. "qwen3_omni_moe_thinker": "qwen2_moe",
  48. "qwen3_next": "qwen2_moe",
  49. "hunyuan_v1_moe": "qwen2_moe",
  50. "flex_olmo": "qwen2_moe",
  51. "olmoe": "qwen2_moe",
  52. "exaone_moe": "qwen2_moe",
  53. "rt_detr_v2": "rt_detr",
  54. "pp_doclayout_v2": "rt_detr",
  55. "pp_doclayout_v3": "rt_detr",
  56. "paligemma": "llava",
  57. "aya_vision": "llava",
  58. "fuyu": "llava",
  59. "got_ocr2": "llava",
  60. "shieldgemma2": "llava",
  61. "gemma3": "llava",
  62. "internvl": "llava",
  63. "llava_next": "llava",
  64. "llava_next_video": "llava",
  65. "llava_onevision": "llava",
  66. "vipllava": "llava",
  67. "video_llava": "llava",
  68. "mistral3": "llava",
  69. "mllama": "llava",
  70. "qwen2_5_vl": "qwen2_vl",
  71. "sam3_tracker_video": "sam3_tracker",
  72. "pp_chart2table": "llava",
  73. "gemma3n_text": "qwen3_5_text",
  74. "qwen3_5_moe_text": "qwen3_5_text",
  75. }
def _build_checkpoint_conversion_mapping():
    """
    Build the table of checkpoint conversion recipes, keyed by `model_type`.

    Each value is an ordered list of `WeightRenaming` / `WeightConverter` transforms. After the explicit entries
    below, every model type listed in `_MODEL_TO_CONVERSION_PATTERN` that has no explicit entry receives a shallow
    copy of its base pattern's list. A new dict is built on every call; `get_checkpoint_conversion_mapping` caches
    the result and hands out deep copies.

    NOTE(review): several target patterns below contain regex syntax (`^`, lookarounds, capture groups such as
    `(.+)`). They appear to be post-processed by `WeightRenaming` (defined in `core_model_loading`, not visible
    here). Confirm how targets are interpreted there before editing them.
    """
    mapping = {
        "llava": [
            WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"),
            WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"),
        ],
        "emu3": [
            WeightRenaming(source_patterns=r"text_model.model", target_patterns="text_model"),
            WeightRenaming(source_patterns=r"text_model.lm_head", target_patterns="lm_head"),
        ],
        "paddleocr_vl": [
            WeightRenaming(source_patterns=r"mlp_AR", target_patterns="model.projector"),
            # Move every `model...` key that is not already under visual/projector/language_model
            WeightRenaming(
                source_patterns=r"^model(?!(\.visual|\.projector|\.language_model))",
                target_patterns="model.language_model",
            ),
        ],
        "qwen2_vl": [
            WeightRenaming(
                source_patterns=r"(?<!_)model(?!\.(language_model|visual))", target_patterns="model.language_model"
            ),
        ],
        "colqwen2": [
            WeightRenaming(source_patterns=r"vlm.model", target_patterns="vlm"),
            WeightRenaming(source_patterns=r"vlm(?!\.(language_model|visual))", target_patterns="vlm.language_model"),
        ],
        "timm_wrapper": [
            # Simply add the prefix `timm_model`. Similar to `base_model_prefix` but also removes the prefix
            # when saving. TODO: Would probably be much cleaner with an `add_prefix` argument in WeightRenaming
            # Note: we don't add `timm_model` when it is part of a bigger VLM, because those already have
            # `timm_model` saved in their state dict keys. Hence the negative lookahead that skips keys starting
            # with `model.`, `backbone.` or `tower.`. Should be fixed by a proper `add_prefix`!
            WeightRenaming(
                source_patterns=r"^(?!(?:model\.|backbone\.|tower\.))(.+)$",
                target_patterns=r"timm_model.\1",
            )
        ],
        "pi0": [
            WeightRenaming(source_patterns=r"state_proj", target_patterns="embed_action_time.state_proj"),
            WeightRenaming(source_patterns=r"action_in_proj", target_patterns="embed_action_time.action_in_proj"),
            WeightRenaming(
                source_patterns=r"action_time_mlp_in", target_patterns="embed_action_time.action_time_mlp_in"
            ),
            WeightRenaming(
                source_patterns=r"action_time_mlp_out", target_patterns="embed_action_time.action_time_mlp_out"
            ),
            WeightRenaming(source_patterns=r"^paligemma_with_expert.paligemma.model", target_patterns="model.vlm"),
            WeightRenaming(source_patterns=r"^paligemma_with_expert.gemma_expert.model", target_patterns="model.dit"),
            # Checkpoints on the Hub only have `lm_head` saved, but PI0 does not create an lm-head, so those
            # weights are routed to the corresponding `embed_tokens` instead.
            WeightRenaming(
                source_patterns=r"^paligemma_with_expert.gemma_expert.lm_head",
                target_patterns="model.dit.embed_tokens",
            ),
            WeightRenaming(
                source_patterns=r"^paligemma_with_expert.paligemma.lm_head",
                target_patterns="model.vlm.language_model.embed_tokens",
            ),
        ],
        # Add the `model.` prefix unless it is already present (negative lookbehind)
        "dinov3_convnext": [WeightRenaming(r"(?<!model\.)stages", r"model.stages")],
        "dinov3_vit": [WeightRenaming(r"(?<!model\.)layer.", r"model.layer.")],
        "timesfm2_5": [
            WeightRenaming("ff0", "fc1"),
            WeightRenaming("ff1", "fc2"),
        ],
        "olmo_hybrid": [
            WeightRenaming("attention_layer_norm", "input_layernorm"),
            WeightRenaming("feedforward_layer_norm", "post_attention_layernorm"),
        ],
        "qwen3_5_text": [
            # Note: the negative lookahead `(?!language_model.)` on the target is to avoid replacing bigger matches
            # when the model is a submodel of the ForConditionalGeneration model
            WeightRenaming(source_patterns=r"^model.language_model.", target_patterns=r"^model.(?!language_model.)"),
        ],
        "sam3_tracker": [
            WeightRenaming(
                source_patterns=r"detector_model.vision_encoder.backbone.", target_patterns="vision_encoder.backbone."
            ),
            WeightRenaming(source_patterns=r"tracker_neck.", target_patterns="vision_encoder.neck."),
        ],
        "t5gemma2_encoder": [
            # Prefix encoder keys with `text_model.`; the lookbehinds skip decoder / vision / already-prefixed keys
            WeightRenaming(r"(?<!decoder\.)(?<!text_model\.)embed_tokens\.", "text_model.embed_tokens."),
            WeightRenaming(r"(?<!decoder\.)(?<!text_model\.)(?<!layer)(?<!_)norm\.", "text_model.norm."),
            WeightRenaming(r"(?<!vision_model.encoder\.)(?<!decoder\.)(?<!text_model\.)layers.", "text_model.layers."),
        ],
        "mixtral": [
            WeightRenaming(".block_sparse_moe.", ".mlp."),
            WeightConverter(
                source_patterns=[
                    ".experts.*.w1.weight",
                    ".experts.*.w3.weight",
                ],  # you give me a list of 2 keys, I collect a list of a list of tensors
                target_patterns=".experts.gate_up_proj",  # target key gets the list of two tensors
                operations=[
                    MergeModulelist(
                        dim=0
                    ),  # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
                    Concatenate(dim=1),  # each process has 2 tensors, gate and up, we concat them into gate_up
                ],  # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
            ),
            WeightConverter(
                source_patterns=[
                    ".experts.*.w2.weight",
                ],
                target_patterns=".experts.down_proj",  # a single source pattern -> one list of per-expert tensors
                operations=[
                    MergeModulelist(
                        dim=0
                    ),  # the single list of per-expert tensors is merged along dim 0 -> we end up with 1 tensor
                ],  # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
            ),
        ],
        "qwen2_moe": [
            # Fuse per-expert gate/up weights into one `gate_up_proj`, and per-expert down weights into `down_proj`
            WeightConverter(
                source_patterns=[
                    "mlp.experts.*.gate_proj.weight",
                    "mlp.experts.*.up_proj.weight",
                ],
                target_patterns="mlp.experts.gate_up_proj",
                operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
            ),
            WeightConverter(
                source_patterns="mlp.experts.*.down_proj.weight",
                target_patterns="mlp.experts.down_proj",
                operations=[MergeModulelist(dim=0)],
            ),
        ],
        "qwen3_vl_moe": [
            # Same key names; only the tensor layout changes (transpose of dims 1 and 2)
            WeightConverter(
                source_patterns="mlp.experts.gate_up_proj",
                target_patterns="mlp.experts.gate_up_proj",
                operations=[Transpose(1, 2, check_dims=True)],
            ),
            WeightConverter(
                source_patterns="mlp.experts.down_proj",
                target_patterns="mlp.experts.down_proj",
                operations=[Transpose(1, 2, check_dims=True)],
            ),
        ],
        "phimoe": [
            WeightRenaming(".block_sparse_moe.", ".mlp."),
            WeightRenaming(".gate.weight", ".router.weight"),
            WeightConverter(
                source_patterns=[
                    ".experts.*.w1.weight",
                    ".experts.*.w3.weight",
                ],
                target_patterns=".experts.gate_up_proj",
                operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
            ),
            WeightConverter(
                source_patterns=".experts.*.w2.weight",
                target_patterns=".experts.down_proj",
                operations=[MergeModulelist(dim=0)],
            ),
        ],
        "lfm2_moe": [
            WeightConverter(
                source_patterns=[
                    "feed_forward.experts.*.w1.weight",
                    "feed_forward.experts.*.w3.weight",
                ],
                target_patterns="feed_forward.experts.gate_up_proj",
                operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
            ),
            WeightConverter(
                source_patterns="feed_forward.experts.*.w2.weight",
                target_patterns="feed_forward.experts.down_proj",
                operations=[MergeModulelist(dim=0)],
            ),
        ],
        "ernie4_5_vl_moe": [
            # vision
            WeightRenaming("vision_model", "vision_tower"),
            # resampler
            WeightRenaming("spatial_linear.0", "spatial_linear.fc1"),
            WeightRenaming("spatial_linear.2", "spatial_linear.fc2"),
            WeightRenaming("spatial_linear.3", "spatial_linear.ln"),
            WeightRenaming("temporal_linear.0", "temporal_linear.fc1"),
            WeightRenaming("temporal_linear.2", "temporal_linear.fc2"),
            WeightRenaming("temporal_linear.3", "temporal_linear.ln"),
            # language model
            WeightRenaming(r"(?<!language_model\.)embed_tokens", "language_model.embed_tokens"),
            WeightRenaming(r"(?<!language_model\.)layers", "language_model.layers"),
            # `(?<!\w)` already excludes a preceding `_`, so the `(?<!_)` lookbehind is redundant (but harmless)
            WeightRenaming(r"(?<!_)(?<!\w)norm\.", "language_model.norm."),
            WeightConverter(
                source_patterns="mlp.gate.weight_1",
                target_patterns="mlp.vision_moe.gate.weight",
                operations=[Transpose(dim0=0, dim1=1)],
            ),
            WeightConverter(
                source_patterns="mlp.gate.weight",
                target_patterns="mlp.text_moe.gate.weight",
                operations=[Transpose(dim0=0, dim1=1)],
            ),
            # One source tensor, two targets (text / vision), split with `Chunk(dim=0)`
            WeightConverter(
                source_patterns=["mlp.moe_statics.e_score_correction_bias"],
                target_patterns=[
                    "mlp.text_moe.gate.moe_statics.e_score_correction_bias",
                    "mlp.vision_moe.gate.moe_statics.e_score_correction_bias",
                ],
                operations=[Chunk(dim=0)],
            ),
            WeightConverter(
                source_patterns=["experts.*.down_proj.weight"],
                target_patterns=[
                    "text_moe.experts.down_proj",
                    "vision_moe.experts.down_proj",
                ],
                operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)],
            ),
            WeightConverter(
                source_patterns=[
                    "experts.*.gate_proj.weight",
                    "experts.*.up_proj.weight",
                ],
                target_patterns=[
                    "text_moe.experts.gate_up_proj",
                    "vision_moe.experts.gate_up_proj",
                ],
                operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)],
            ),
        ],
        "detr": [
            WeightRenaming("backbone.conv_encoder", "backbone"),
            WeightRenaming("out_proj", "o_proj"),
            WeightRenaming(r"layers.(\d+).fc1", r"layers.\1.mlp.fc1"),
            WeightRenaming(r"layers.(\d+).fc2", r"layers.\1.mlp.fc2"),
            # `DetrForSegmentation`
            WeightRenaming("bbox_attention.q_linear", "bbox_attention.q_proj"),
            WeightRenaming("bbox_attention.k_linear", "bbox_attention.k_proj"),
            # Mask head refactor
            WeightRenaming("mask_head.lay1", "mask_head.conv1.conv"),
            WeightRenaming("mask_head.gn1", "mask_head.conv1.norm"),
            WeightRenaming("mask_head.lay2", "mask_head.conv2.conv"),
            WeightRenaming("mask_head.gn2", "mask_head.conv2.norm"),
            WeightRenaming("mask_head.adapter1", "mask_head.fpn_stages.0.fpn_adapter"),
            WeightRenaming("mask_head.lay3", "mask_head.fpn_stages.0.refine.conv"),
            WeightRenaming("mask_head.gn3", "mask_head.fpn_stages.0.refine.norm"),
            WeightRenaming("mask_head.adapter2", "mask_head.fpn_stages.1.fpn_adapter"),
            WeightRenaming("mask_head.lay4", "mask_head.fpn_stages.1.refine.conv"),
            WeightRenaming("mask_head.gn4", "mask_head.fpn_stages.1.refine.norm"),
            WeightRenaming("mask_head.adapter3", "mask_head.fpn_stages.2.fpn_adapter"),
            WeightRenaming("mask_head.lay5", "mask_head.fpn_stages.2.refine.conv"),
            WeightRenaming("mask_head.gn5", "mask_head.fpn_stages.2.refine.norm"),
            WeightRenaming("mask_head.out_lay", "mask_head.output_conv"),
        ],
        "rt_detr": [
            WeightRenaming("out_proj", "o_proj"),
            WeightRenaming(r"layers.(\d+).fc1", r"layers.\1.mlp.fc1"),
            WeightRenaming(r"layers.(\d+).fc2", r"layers.\1.mlp.fc2"),
            WeightRenaming(r"encoder.encoder.(\d+).layers", r"encoder.aifi.\1.layers"),
        ],
        "conditional_detr": [
            WeightRenaming("backbone.conv_encoder", "backbone"),
            WeightRenaming("self_attn.out_proj", "self_attn.o_proj"),
            WeightRenaming("encoder_attn.out_proj", "encoder_attn.o_proj"),
            WeightRenaming(r"layers.(\d+).fc1", r"layers.\1.mlp.fc1"),
            WeightRenaming(r"layers.(\d+).fc2", r"layers.\1.mlp.fc2"),
            # Decoder self-attention projections moved into self_attn module
            WeightRenaming(r"decoder.layers.(\d+).sa_qcontent_proj", r"decoder.layers.\1.self_attn.q_content_proj"),
            WeightRenaming(r"decoder.layers.(\d+).sa_qpos_proj", r"decoder.layers.\1.self_attn.q_pos_proj"),
            WeightRenaming(r"decoder.layers.(\d+).sa_kcontent_proj", r"decoder.layers.\1.self_attn.k_content_proj"),
            WeightRenaming(r"decoder.layers.(\d+).sa_kpos_proj", r"decoder.layers.\1.self_attn.k_pos_proj"),
            WeightRenaming(r"decoder.layers.(\d+).sa_v_proj", r"decoder.layers.\1.self_attn.v_proj"),
            # Decoder cross-attention projections moved into encoder_attn module
            WeightRenaming(r"decoder.layers.(\d+).ca_qcontent_proj", r"decoder.layers.\1.encoder_attn.q_content_proj"),
            WeightRenaming(r"decoder.layers.(\d+).ca_qpos_proj", r"decoder.layers.\1.encoder_attn.q_pos_proj"),
            WeightRenaming(r"decoder.layers.(\d+).ca_kcontent_proj", r"decoder.layers.\1.encoder_attn.k_content_proj"),
            WeightRenaming(r"decoder.layers.(\d+).ca_kpos_proj", r"decoder.layers.\1.encoder_attn.k_pos_proj"),
            WeightRenaming(r"decoder.layers.(\d+).ca_v_proj", r"decoder.layers.\1.encoder_attn.v_proj"),
            WeightRenaming(
                r"decoder.layers.(\d+).ca_qpos_sine_proj", r"decoder.layers.\1.encoder_attn.q_pos_sine_proj"
            ),
            # The rest of patterns are used only in `ConditionalDetrForSegmentation`
            WeightRenaming("bbox_attention.q_linear", "bbox_attention.q_proj"),
            WeightRenaming("bbox_attention.k_linear", "bbox_attention.k_proj"),
            # Mask head refactor (same renames as in the `detr` entry)
            WeightRenaming("mask_head.lay1", "mask_head.conv1.conv"),
            WeightRenaming("mask_head.gn1", "mask_head.conv1.norm"),
            WeightRenaming("mask_head.lay2", "mask_head.conv2.conv"),
            WeightRenaming("mask_head.gn2", "mask_head.conv2.norm"),
            WeightRenaming("mask_head.adapter1", "mask_head.fpn_stages.0.fpn_adapter"),
            WeightRenaming("mask_head.lay3", "mask_head.fpn_stages.0.refine.conv"),
            WeightRenaming("mask_head.gn3", "mask_head.fpn_stages.0.refine.norm"),
            WeightRenaming("mask_head.adapter2", "mask_head.fpn_stages.1.fpn_adapter"),
            WeightRenaming("mask_head.lay4", "mask_head.fpn_stages.1.refine.conv"),
            WeightRenaming("mask_head.gn4", "mask_head.fpn_stages.1.refine.norm"),
            WeightRenaming("mask_head.adapter3", "mask_head.fpn_stages.2.fpn_adapter"),
            WeightRenaming("mask_head.lay5", "mask_head.fpn_stages.2.refine.conv"),
            WeightRenaming("mask_head.gn5", "mask_head.fpn_stages.2.refine.norm"),
            WeightRenaming("mask_head.out_lay", "mask_head.output_conv"),
        ],
        "deformable_detr": [
            WeightRenaming("backbone.conv_encoder", "backbone"),
            WeightRenaming("self_attn.out_proj", "self_attn.o_proj"),
            WeightRenaming(r"layers.(\d+).fc1", r"layers.\1.mlp.fc1"),
            WeightRenaming(r"layers.(\d+).fc2", r"layers.\1.mlp.fc2"),
        ],
        "d_fine": [
            WeightRenaming("out_proj", "o_proj"),
            WeightRenaming(r"layers.(\d+).fc1", r"layers.\1.mlp.layers.0"),
            WeightRenaming(r"layers.(\d+).fc2", r"layers.\1.mlp.layers.1"),
            WeightRenaming(r"encoder.encoder.(\d+).layers", r"encoder.aifi.\1.layers"),
        ],
        "nemotron_h": [
            WeightRenaming("backbone.", "model."),
            WeightRenaming("embedding.weight", "embeddings.weight"),
            # Only up/down projections here (no gate), each merged across experts
            WeightConverter(
                source_patterns=[
                    "mixer.experts.*.up_proj.weight",
                ],
                target_patterns="mixer.experts.up_proj",
                operations=[MergeModulelist(dim=0)],
            ),
            WeightConverter(
                source_patterns=[
                    "mixer.experts.*.down_proj.weight",
                ],
                target_patterns="mixer.experts.down_proj",
                operations=[MergeModulelist(dim=0)],
            ),
        ],
        "jamba": [
            WeightConverter(
                source_patterns=[
                    "feed_forward.experts.*.gate_proj.weight",
                    "feed_forward.experts.*.up_proj.weight",
                ],
                target_patterns="feed_forward.experts.gate_up_proj",
                operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
            ),
            WeightConverter(
                source_patterns="feed_forward.experts.*.down_proj.weight",
                target_patterns="feed_forward.experts.down_proj",
                operations=[MergeModulelist(dim=0)],
            ),
        ],
        # Renames applied to every model when `add_legacy=True` in `get_model_conversion_mapping`
        "legacy": [
            WeightRenaming(
                source_patterns="LayerNorm.gamma",
                target_patterns="LayerNorm.weight",
            ),
            WeightRenaming(
                source_patterns="LayerNorm.beta",
                target_patterns="LayerNorm.bias",
            ),
        ],
        "nomic_bert": [
            WeightRenaming(r"encoder.layers", r"layers"),
            WeightRenaming(r"emb_ln", r"embeddings.LayerNorm"),
            WeightRenaming(r"attn.out_proj", r"self_attn.o_proj"),
            WeightRenaming(r"fc11", r"up_proj"),
            WeightRenaming(r"fc12", r"gate_proj"),
            WeightRenaming(r"fc2", r"down_proj"),
            WeightRenaming(r"norm1", r"post_attention_layernorm"),
            WeightRenaming(
                r"norm2",
                r"post_mlp_layernorm",
            ),
            # presumably splits the fused `Wqkv` weight along dim 0 into one chunk per target (q, k, v)
            WeightConverter(
                source_patterns=["attn.Wqkv"],
                target_patterns=[
                    "self_attn.q_proj",
                    "self_attn.k_proj",
                    "self_attn.v_proj",
                ],
                operations=[Chunk(dim=0)],
            ),
        ],
        "jina_embeddings_v3": [
            WeightRenaming(source_patterns="emb_ln", target_patterns="embeddings.LayerNorm"),
            WeightRenaming(source_patterns="encoder.layers", target_patterns="layers"),
            WeightConverter(
                source_patterns="mixer.Wqkv",
                target_patterns=[
                    "self_attn.q_proj",
                    "self_attn.k_proj",
                    "self_attn.v_proj",
                ],
                operations=[Chunk(dim=0)],
            ),
            WeightRenaming(source_patterns="mixer.out_proj", target_patterns="self_attn.o_proj"),
            WeightRenaming(source_patterns="norm1", target_patterns="post_attention_layernorm"),
            WeightRenaming(source_patterns="norm2", target_patterns="post_mlp_layernorm"),
        ],
    }
    # Old weight-norm parameter names (`weight_g` / `weight_v`) -> parametrization names
    # (presumably from the older `weight_norm` API; `$` anchors the match to the end of the key)
    mapping["legacy"] += [
        WeightRenaming(
            source_patterns=".weight_g$",
            target_patterns=".parametrizations.weight.original0",
        ),
        WeightRenaming(
            source_patterns=".weight_v$",
            target_patterns=".parametrizations.weight.original1",
        ),
    ]
    # Explicit entry: the `qwen2_moe` recipe plus a renaming of the `e_score_correction_bias` statics.
    # Because the key exists here, the alias loop at the end does not overwrite it.
    mapping["ernie4_5_moe"] = [
        WeightRenaming("mlp.moe_statics.e_score_correction_bias", "mlp.gate.moe_statics.e_score_correction_bias"),
        WeightConverter(
            source_patterns=[
                "mlp.experts.*.gate_proj.weight",
                "mlp.experts.*.up_proj.weight",
            ],
            target_patterns="mlp.experts.gate_up_proj",
            operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
        ),
        WeightConverter(
            source_patterns="mlp.experts.*.down_proj.weight",
            target_patterns="mlp.experts.down_proj",
            operations=[MergeModulelist(dim=0)],
        ),
    ]
    # `.copy()` is shallow: a new list, but the transform objects are shared with the base entry.
    # That is fine because `get_checkpoint_conversion_mapping` deep-copies on every read.
    mapping["minimax_m2"] = mapping["mixtral"].copy()
    mapping["minimax_m2"] += [
        WeightRenaming(".block_sparse_moe.e_score_correction_bias", ".mlp.e_score_correction_bias"),
    ]
    mapping["exaone_moe"] = mapping["qwen2_moe"].copy()
    mapping["exaone_moe"] += [WeightRenaming("mlp.e_score_correction_bias", "mlp.gate.e_score_correction_bias")]
    # Same transforms as the `qwen2_moe` entry. `solar_open` is also listed in `_MODEL_TO_CONVERSION_PATTERN`,
    # but this explicit entry takes precedence (the alias loop skips existing keys).
    mapping["solar_open"] = [
        WeightConverter(
            source_patterns=[
                "mlp.experts.*.gate_proj.weight",
                "mlp.experts.*.up_proj.weight",
            ],
            target_patterns="mlp.experts.gate_up_proj",
            operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
        ),
        WeightConverter(
            source_patterns="mlp.experts.*.down_proj.weight",
            target_patterns="mlp.experts.down_proj",
            operations=[MergeModulelist(dim=0)],
        ),
    ]
    mapping["cohere_asr"] = [
        WeightRenaming(r"encoder\.pre_encode\.conv\.", r"encoder.subsampling.layers."),
        WeightRenaming(r"encoder\.pre_encode\.out\.", r"encoder.subsampling.linear."),
        WeightRenaming(r"transf_decoder\._embedding\.position_embedding\.pos_enc", r"decoder.pos_emb.weight"),
        WeightRenaming(r"transf_decoder\._embedding\.token_embedding", r"decoder.embed_tokens"),
        WeightRenaming(r"transf_decoder\._embedding\.layer_norm", r"decoder.embedding_layernorm"),
        WeightRenaming(r"transf_decoder\._decoder\.final_layer_norm", r"decoder.norm"),
        WeightRenaming(r"transf_decoder\._decoder\.layers", r"decoder.layers"),
        WeightRenaming(r"encoder_decoder_proj\.", r"decoder.proj."),
        # NOTE(review): these targets contain a literal `(.+)` rather than a `\1` backreference. Presumably
        # `WeightRenaming` rewrites capture groups in targets into backreferences; confirm in
        # `core_model_loading` before changing them.
        WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_q", r"encoder.(.+).self_attn.q_proj"),
        WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_k", r"encoder.(.+).self_attn.k_proj"),
        WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_v", r"encoder.(.+).self_attn.v_proj"),
        WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_out", r"encoder.(.+).self_attn.o_proj"),
        WeightRenaming(r"encoder\.(.+)\.self_attn\.linear_pos", r"encoder.(.+).self_attn.relative_k_proj"),
        WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_u", r"encoder.(.+).self_attn.bias_u"),
        WeightRenaming(r"encoder\.(.+)\.self_attn\.pos_bias_v", r"encoder.(.+).self_attn.bias_v"),
        WeightRenaming(r"\.first_sub_layer\.query_net", r".self_attn.q_proj"),
        WeightRenaming(r"\.first_sub_layer\.key_net", r".self_attn.k_proj"),
        WeightRenaming(r"\.first_sub_layer\.value_net", r".self_attn.v_proj"),
        WeightRenaming(r"\.first_sub_layer\.out_projection", r".self_attn.o_proj"),
        WeightRenaming(r"\.second_sub_layer\.query_net", r".encoder_attn.q_proj"),
        WeightRenaming(r"\.second_sub_layer\.key_net", r".encoder_attn.k_proj"),
        WeightRenaming(r"\.second_sub_layer\.value_net", r".encoder_attn.v_proj"),
        WeightRenaming(r"\.second_sub_layer\.out_projection", r".encoder_attn.o_proj"),
        WeightRenaming(r"\.third_sub_layer\.dense_in", r".mlp.fc1"),
        WeightRenaming(r"\.third_sub_layer\.dense_out", r".mlp.fc2"),
        WeightRenaming(r"\.layer_norm_1\.", r".input_layernorm."),
        WeightRenaming(r"\.layer_norm_2\.", r".post_attention_layernorm."),
        WeightRenaming(r"\.layer_norm_3\.", r".final_layernorm."),
        WeightRenaming(r"\.conv\.batch_norm", r".conv.norm"),
        WeightRenaming(r"log_softmax\.mlp\.layer0", r"proj_out"),
    ]
    # Alias model types get a (shallow) copy of their base recipe, unless they were defined explicitly above
    for model_type, base_pattern in _MODEL_TO_CONVERSION_PATTERN.items():
        if model_type in mapping:
            continue
        mapping[model_type] = mapping[base_pattern].copy()
    return mapping
  545. _checkpoint_conversion_mapping_cache = None
  546. def get_checkpoint_conversion_mapping(model_type):
  547. global _checkpoint_conversion_mapping_cache
  548. if _checkpoint_conversion_mapping_cache is None:
  549. _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
  550. return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type))
  551. def register_checkpoint_conversion_mapping(
  552. model_type: str,
  553. mapping: list[WeightConverter | WeightRenaming],
  554. overwrite: bool = False,
  555. ) -> None:
  556. global _checkpoint_conversion_mapping_cache
  557. if _checkpoint_conversion_mapping_cache is None:
  558. _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
  559. if model_type in _checkpoint_conversion_mapping_cache and not overwrite:
  560. raise ValueError(f"Model type {model_type} already exists in the checkpoint conversion mapping.")
  561. _checkpoint_conversion_mapping_cache[model_type] = mapping
  562. def extract_weight_conversions_for_model(model: PreTrainedModel) -> list[WeightConverter | WeightRenaming] | None:
  563. model_type = getattr(model.config, "model_type", None)
  564. if model_type is not None:
  565. model_specific_conversions = get_checkpoint_conversion_mapping(model_type)
  566. return model_specific_conversions
  567. return None
  568. def get_model_conversion_mapping(
  569. model: PreTrainedModel,
  570. key_mapping: dict[str, str] | None = None,
  571. hf_quantizer: HfQuantizer | None = None,
  572. add_legacy: bool = True,
  573. ) -> list[WeightConverter | WeightRenaming]:
  574. """
  575. For a given `model`, obtain the weight conversion mapping if any are registered either as a simple renaming
  576. `_checkpoint_conversion_mapping` class argument, or in the general WeightConverter mapping.
  577. """
  578. # Lazy import to avoid circular import issues
  579. from .modeling_utils import PreTrainedModel
  580. # note: this function is used in PEFT, so changing the API requires coordination
  581. weight_conversions = []
  582. # Load models with explicit, user-provided key mapping
  583. if key_mapping is not None:
  584. weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()]
  585. # Model have several `PreTrainedModel` within with the same model type
  586. # For ex: XForConditionalGeneration -> XModel. We don't want to apply the same
  587. # conversion pattern twice because of that
  588. seen_model_types = set()
  589. if (conversions := extract_weight_conversions_for_model(model)) is not None:
  590. weight_conversions.extend(conversions)
  591. seen_model_types.add(model.config.model_type)
  592. # Recurse over submodules and collect all conversions
  593. for submodule in model.modules():
  594. if (
  595. submodule is not model
  596. and isinstance(submodule, PreTrainedModel)
  597. and submodule.config.model_type not in seen_model_types
  598. ):
  599. conversions = extract_weight_conversions_for_model(submodule)
  600. if conversions is not None:
  601. weight_conversions.extend(conversions)
  602. seen_model_types.add(submodule.config.model_type)
  603. if add_legacy:
  604. weight_conversions.extend(get_checkpoint_conversion_mapping("legacy"))
  605. # Add the ones from the quantizer as well if provided
  606. if hf_quantizer is not None:
  607. weight_conversions.extend(hf_quantizer.get_weight_conversions())
  608. return weight_conversions