modeling_rope_utils.py 56 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import warnings
  16. from collections.abc import Callable
  17. from functools import wraps
  18. from typing import TYPE_CHECKING, Optional, TypedDict
  19. from .utils import is_torch_available, logging
  20. logger = logging.get_logger(__name__)
  21. if is_torch_available():
  22. import torch
  23. if TYPE_CHECKING:
  24. from .configuration_utils import PreTrainedConfig
  25. def dynamic_rope_update(rope_forward):
  26. """
  27. Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
  28. (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
  29. Args:
  30. rope_forward (Callable):
  31. The forward pass of the RoPE implementation.
  32. Returns:
  33. The decorated forward pass.
  34. """
  35. def longrope_frequency_update(self, position_ids, device, layer_type=None):
  36. """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
  37. seq_len = torch.max(position_ids) + 1
  38. if layer_type is None:
  39. rope_type = self.rope_type
  40. original_inv_freq = self.original_inv_freq
  41. prefix = ""
  42. original_max_position_embeddings = self.config.rope_parameters["original_max_position_embeddings"]
  43. else:
  44. rope_type = self.rope_type[layer_type]
  45. original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
  46. prefix = f"{layer_type}_"
  47. original_max_position_embeddings = self.config.rope_parameters[layer_type][
  48. "original_max_position_embeddings"
  49. ]
  50. if seq_len > original_max_position_embeddings:
  51. if not hasattr(self, f"{layer_type}_long_inv_freq"):
  52. rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
  53. long_inv_freq, _ = rope_init_fn(
  54. self.config,
  55. device,
  56. seq_len=original_max_position_embeddings + 1,
  57. layer_type=layer_type,
  58. )
  59. self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False)
  60. setattr(self, f"{prefix}long_inv_freq", long_inv_freq)
  61. else:
  62. # This .to() is needed if the model has been moved to a device after being initialized (because
  63. # the buffer is automatically moved, but not the original copy)
  64. original_inv_freq = original_inv_freq.to(device)
  65. self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
  66. setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
  67. def dynamic_frequency_update(self, position_ids, device, layer_type=None):
  68. """
  69. dynamic RoPE layers should recompute `inv_freq` in the following situations:
  70. 1 - growing beyond the cached sequence length (allow scaling)
  71. 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
  72. """
  73. seq_len = torch.max(position_ids) + 1
  74. if layer_type is None:
  75. rope_type = self.rope_type
  76. max_seq_len_cached = self.max_seq_len_cached
  77. original_inv_freq = self.original_inv_freq
  78. prefix = ""
  79. else:
  80. rope_type = self.rope_type[layer_type]
  81. max_seq_len_cached = getattr(self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached)
  82. original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
  83. prefix = f"{layer_type}_"
  84. if seq_len > max_seq_len_cached: # growth
  85. rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
  86. inv_freq, self.attention_scaling = rope_init_fn(
  87. self.config,
  88. device,
  89. seq_len=seq_len,
  90. layer_type=layer_type,
  91. )
  92. # TODO joao: may break with compilation
  93. self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False)
  94. setattr(self, f"{layer_type}_max_seq_len_cached", seq_len)
  95. if seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len: # reset
  96. # This .to() is needed if the model has been moved to a device after being initialized (because
  97. # the buffer is automatically moved, but not the original copy)
  98. original_inv_freq = original_inv_freq.to(device)
  99. self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False)
  100. setattr(self, f"{prefix}original_inv_freq", original_inv_freq)
  101. setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len)
  102. @wraps(rope_forward)
  103. def wrapper(self, x, position_ids, layer_type=None):
  104. rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
  105. kwargs = {"layer_type": layer_type} if layer_type is not None else {}
  106. if "dynamic" in rope_type:
  107. dynamic_frequency_update(self, position_ids, device=x.device, **kwargs)
  108. elif rope_type == "longrope":
  109. longrope_frequency_update(self, position_ids, device=x.device, **kwargs)
  110. return rope_forward(self, x, position_ids, **kwargs)
  111. return wrapper
  112. def _compute_linear_scaling_rope_parameters(
  113. config: Optional["PreTrainedConfig"] = None,
  114. device: Optional["torch.device"] = None,
  115. seq_len: int | None = None,
  116. layer_type: str | None = None,
  117. ) -> tuple["torch.Tensor", float]:
  118. """
  119. Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
  120. Args:
  121. config ([`~transformers."PreTrainedConfig"`]):
  122. The model configuration. This function assumes that the config will provide at least the following
  123. properties:
  124. * rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
  125. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  126. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  127. Additionally, this function will make use of the following properties if they are found in the config:
  128. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  129. derived as hidden_size // num_attention_heads.
  130. * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
  131. the first fraction of the head_dim. Defaults to 1.0.
  132. device (`torch.device`):
  133. The device to use for initialization of the inverse frequencies.
  134. seq_len (`int`, *optional*):
  135. The current sequence length. Unused for this type of RoPE.
  136. Returns:
  137. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  138. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  139. """
  140. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  141. config.standardize_rope_params()
  142. rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
  143. factor = rope_parameters_dict["factor"]
  144. # Gets the default RoPE parameters
  145. base = rope_parameters_dict["rope_theta"]
  146. partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
  147. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  148. dim = int(head_dim * partial_rotary_factor)
  149. attention_factor = 1.0 # Unused in this type of RoPE
  150. # Compute the inverse frequencies
  151. inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
  152. # Then applies linear scaling to the frequencies.
  153. # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
  154. # applying scaling to the inverse frequencies is equivalent.
  155. inv_freq /= factor
  156. return inv_freq, attention_factor
  157. def _compute_proportional_rope_parameters(
  158. config: Optional["PreTrainedConfig"] = None,
  159. device: Optional["torch.device"] = None,
  160. seq_len: int | None = None,
  161. layer_type: str | None = None,
  162. head_dim_key: str = "head_dim",
  163. ) -> tuple["torch.Tensor", float]:
  164. """
  165. Computes the inverse frequencies with proportional RoPE.
  166. Args:
  167. config ([`~transformers.PretrainedConfig`]):
  168. The model configuration. This function assumes that the config will provide at least the following
  169. properties:
  170. * rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
  171. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  172. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  173. Additionally, this function will make use of the following properties if they are found in the config:
  174. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  175. derived as hidden_size // num_attention_heads.
  176. * partial_rotary_factor (`float`, *optional*, defaults to 1.0): The proportion of the embedding dimension
  177. to apply rotary positional encoding, e.g., [0.0, 0.25, 0.5, 0.75, 1.0]. Unlike other RoPE functions
  178. that use this parameter, proportional RoPE will always return an encoding that is the size of
  179. `head_dim`.
  180. device (`torch.device`):
  181. The device to use for initialization of the inverse frequencies.
  182. seq_len (`int`, *optional*):
  183. The current sequence length. Unused for this type of RoPE.
  184. Returns:
  185. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  186. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  187. """
  188. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  189. config.standardize_rope_params()
  190. rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
  191. head_dim = getattr(config, head_dim_key, None) or config.hidden_size // config.num_attention_heads
  192. base = rope_parameters_dict["rope_theta"]
  193. factor = rope_parameters_dict.get("factor", 1.0)
  194. rope_proportion = rope_parameters_dict.get("partial_rotary_factor", 1.0)
  195. attention_factor = 1.0 # Unused in this type of RoPE
  196. rope_angles = int(rope_proportion * head_dim // 2)
  197. inv_freq_rotated = 1.0 / (
  198. base
  199. ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / head_dim)
  200. )
  201. nope_angles = head_dim // 2 - rope_angles
  202. if nope_angles > 0:
  203. inv_freq = torch.cat(
  204. (
  205. inv_freq_rotated,
  206. torch.zeros(nope_angles, dtype=torch.float32, device=device),
  207. ),
  208. dim=0,
  209. )
  210. else:
  211. inv_freq = inv_freq_rotated
  212. inv_freq /= factor
  213. return inv_freq, attention_factor
  214. def _compute_dynamic_ntk_parameters(
  215. config: Optional["PreTrainedConfig"] = None,
  216. device: Optional["torch.device"] = None,
  217. seq_len: int | None = None,
  218. layer_type: str | None = None,
  219. ) -> tuple["torch.Tensor", float]:
  220. """
  221. Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
  222. Args:
  223. config ([`~transformers."PreTrainedConfig"`]):
  224. The model configuration. This function assumes that the config will provide at least the following
  225. properties:
  226. * rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
  227. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  228. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  229. * max_position_embeddings (`int`): The default sequence length used to update the dynamic RoPE at
  230. inference time
  231. * rope_parameters (`dict[str, float]`): The standard RoPE scaling parameters, from which `factor`
  232. will be accessed. The value of `factor` is used to determine the new base frequency, along with the
  233. current sequence length (seq_len), the maximum positional embeddings (max_position_embeddings), and the
  234. computed dimensionality (dim) of the rotary embeddings. If seq_len <= max_position_embeddings, this
  235. factor has no effect. If seq_len <= max_position_embeddings, this factor effectively stretches the
  236. context window using an exponent derived from `dim`.
  237. Additionally, this function will make use of the following properties if they are found in the config:
  238. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  239. derived as hidden_size // num_attention_heads.
  240. * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
  241. the first fraction of the head_dim. Defaults to 1.0.
  242. device (`torch.device`):
  243. The device to use for initialization of the inverse frequencies.
  244. seq_len (`int`, *optional*):
  245. The current sequence length, used to update the dynamic RoPE at inference time. If `None` or shorter than
  246. max_position_embeddings, this value will be overridden by max_position_embeddings.
  247. Returns:
  248. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  249. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  250. """
  251. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  252. config.standardize_rope_params()
  253. rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
  254. base = rope_parameters_dict["rope_theta"]
  255. partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
  256. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  257. dim = int(head_dim * partial_rotary_factor)
  258. factor = rope_parameters_dict["factor"]
  259. attention_factor = 1.0 # Unused in this type of RoPE
  260. # seq_len: default to max_position_embeddings, e.g. at init time
  261. if seq_len is None:
  262. seq_len = config.max_position_embeddings
  263. elif isinstance(seq_len, torch.Tensor):
  264. seq_len = torch.maximum(
  265. seq_len,
  266. torch.tensor(config.max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
  267. )
  268. else:
  269. seq_len = max(seq_len, config.max_position_embeddings)
  270. # Compute the inverse frequencies
  271. base = base * ((factor * seq_len / config.max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
  272. inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
  273. return inv_freq, attention_factor
  274. def _compute_yarn_parameters(
  275. config: "PreTrainedConfig",
  276. device: Optional["torch.device"] = None,
  277. seq_len: int | None = None,
  278. layer_type: str | None = None,
  279. ) -> tuple["torch.Tensor", float]:
  280. """
  281. Computes the inverse frequencies with NTK scaling. Please refer to the
  282. [original paper](https://huggingface.co/papers/2309.00071)
  283. Args:
  284. config ([`~transformers."PreTrainedConfig"`]):
  285. The model configuration. This function assumes that the config will provide at least the following
  286. properties:
  287. * rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
  288. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  289. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  290. * max_position_embeddings (`int`): The maximum length of the positional embeddings.
  291. * rope_parameters (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following
  292. keys will be accessed:
  293. * `attention_factor` (`float`, *optional*): The scaling factor to be applied to the computed cos/sin.
  294. If None, the value is inferred from `factor`, `mscale`, and `mscale_all_dim` as available.
  295. * `beta_fast` (`float`, *optional*, defaults to 32): Parameter to set the boundary for extrapolation
  296. (only) in the linear ramp function.
  297. * `beta_slow` (`float`, *optional*, defaults to 1): Parameter to set the boundary for interpolation
  298. (only) in the linear ramp function.
  299. * `factor` (`float`, *optional*): The scaling factor applied when interpolating the position IDs to
  300. extend the possible context length. Additionally, if `attention_factor` is None, the log of this
  301. value is used to compute a value for `attention_factor`, possibly in conjunciton with `mscale` and
  302. `mscale_all_dim`, if provided.
  303. * `mscale` (`float`, *optional*): If `attention_factor` is None and both `mscale` and
  304. `mscale_all_dim` are provided, `mscale` acts scalar augmenting `log(factor)` when computing the
  305. numerator for the inferred value of `attention_factor`. If not provided, `attention_factor` will be
  306. calculated based on `factor` only.
  307. * `mscale_all_dim` (`float`, *optional*): If `attention_factor` is None and both `mscale` and
  308. `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing
  309. the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor`
  310. will be calculated based on `factor` only.
  311. * `original_max_position_embeddings` (`int`): The original max position embeddings used during pretraining.
  312. * `truncate` (`bool`, *optional*): Whether to truncate the correction range.
  313. Additionally, this function will make use of the following properties if they are found in the config:
  314. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  315. derived as hidden_size // num_attention_heads.
  316. * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies
  317. will be returned for the first fraction of the head_dim.
  318. device (`torch.device`):
  319. The device to use for initialization of the inverse frequencies.
  320. seq_len (`int`, *optional*):
  321. The current sequence length. Unused for this type of RoPE.
  322. Returns:
  323. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  324. post-processing scaling factor applied to the computed cos/sin.
  325. """
  326. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  327. config.standardize_rope_params()
  328. rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
  329. base = rope_parameters_dict["rope_theta"]
  330. partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
  331. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  332. dim = int(head_dim * partial_rotary_factor)
  333. factor = rope_parameters_dict["factor"]
  334. attention_factor = rope_parameters_dict.get("attention_factor")
  335. mscale = rope_parameters_dict.get("mscale")
  336. mscale_all_dim = rope_parameters_dict.get("mscale_all_dim")
  337. original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
  338. # NOTE: DeekSeek-V3 (and potentially other models) have `original_max_position_embeddings` field
  339. # containing the pretrained value. They use the ratio between `max_position_embeddings` and this value
  340. # to compute the default attention scaling factor, instead of using `factor`.
  341. if factor is None:
  342. factor = config.max_position_embeddings / original_max_position_embeddings
  343. def get_mscale(scale, mscale=1):
  344. if scale <= 1:
  345. return 1.0
  346. return 0.1 * mscale * math.log(scale) + 1.0
  347. # Sets the attention factor as suggested in the paper
  348. if attention_factor is None:
  349. if mscale and mscale_all_dim:
  350. attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
  351. else:
  352. attention_factor = get_mscale(factor)
  353. # Optional config options
  354. # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
  355. beta_fast = rope_parameters_dict.get("beta_fast") or 32
  356. beta_slow = rope_parameters_dict.get("beta_slow") or 1
  357. # Compute the inverse frequencies
  358. def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
  359. """Inverse dimension formula to find the dimension based on the number of rotations"""
  360. return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
  361. def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
  362. """Find dimension range bounds based on rotations"""
  363. low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
  364. high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
  365. if truncate:
  366. low = math.floor(low)
  367. high = math.ceil(high)
  368. return max(low, 0), min(high, dim - 1)
  369. def linear_ramp_factor(min, max, dim):
  370. if min == max:
  371. max += 0.001 # Prevent singularity
  372. linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
  373. ramp_func = torch.clamp(linear_func, 0, 1)
  374. return ramp_func
  375. # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
  376. # to expand the possible context length. In other words, interpolation = apply scaling factor.
  377. pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
  378. inv_freq_extrapolation = 1.0 / pos_freqs
  379. inv_freq_interpolation = 1.0 / (factor * pos_freqs)
  380. truncate = config.rope_parameters.get("truncate", True)
  381. low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
  382. # Get n-dimensional rotational scaling corrected for extrapolation
  383. inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
  384. inv_freq = (
  385. inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
  386. + inv_freq_extrapolation * inv_freq_extrapolation_factor
  387. )
  388. return inv_freq, attention_factor
  389. def _compute_longrope_parameters(
  390. config: "PreTrainedConfig",
  391. device: Optional["torch.device"] = None,
  392. seq_len: int | None = None,
  393. layer_type: str | None = None,
  394. ) -> tuple["torch.Tensor", float]:
  395. """
  396. Computes the inverse frequencies with LongRoPE scaling. Please refer to the
  397. [original implementation](https://github.com/microsoft/LongRoPE)
  398. Args:
  399. config ([`~transformers."PreTrainedConfig"`]):
  400. The model configuration. This function assumes that the config will provide at least the following
  401. properties:
  402. * rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
  403. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  404. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  405. * max_position_embeddings (`int`): The maximum length of the positional embeddings.
  406. * original_max_position_embeddings (`int`, *optional*): The original max position embeddings used during
  407. pretraining. If not provided, defaults to `max_position_embeddings`.
  408. * rope_parameters (`dict[str, float]`): The standard RoPE scaling parameters, from which the following keys
  409. will be accessed:
  410. * `attention_factor` (`float`, *optional*): The scaling factor to be applied on the attention
  411. computation. If unspecified, it defaults to value recommended by the implementation, inferred from
  412. the value of `factor`.
  413. * `factor` (`float`, *optional*): The scaling factor to apply to the RoPE embeddings. If both
  414. `max_position_embeddings` and `original_max_position_embeddings` are provided, this value will be
  415. overridden s the ratio between those values.
  416. * `long_factor` (`float`, *optional*): The scale factor applied when computing the inverse
  417. frequencies if `seq_len` is provided and greater than `original_max_position_embeddings`.
  418. * `short_factor` (`float`, *optional*): The scale factor applied when computing the inverse
  419. frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`.
  420. Additionally, this function will make use of the following properties if they are found in the config:
  421. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  422. derived as hidden_size // num_attention_heads.
  423. * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies
  424. will be returned for the first fraction of the head_dim.
  425. device (`torch.device`):
  426. The device to use for initialization of the inverse frequencies.
  427. seq_len (`int`, *optional*):
  428. The current sequence length.
  429. Returns:
  430. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  431. post-processing scaling factor applied to the computed cos/sin.
  432. """
  433. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  434. config.standardize_rope_params()
  435. rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
  436. base = rope_parameters_dict["rope_theta"]
  437. partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
  438. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  439. dim = int(head_dim * partial_rotary_factor)
  440. long_factor = rope_parameters_dict["long_factor"]
  441. short_factor = rope_parameters_dict["short_factor"]
  442. factor = rope_parameters_dict.get("factor")
  443. attention_factor = rope_parameters_dict.get("attention_factor")
  444. original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"]
  445. # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
  446. # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
  447. # values to compute the default attention scaling factor, instead of using `factor`.
  448. if factor is None:
  449. factor = config.max_position_embeddings / original_max_position_embeddings
  450. # Sets the attention factor as suggested in the paper
  451. if attention_factor is None:
  452. if factor <= 1.0:
  453. attention_factor = 1.0
  454. else:
  455. attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
  456. # Compute the inverse frequencies -- scaled based on the target sequence length
  457. if seq_len and seq_len > original_max_position_embeddings:
  458. ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
  459. else:
  460. ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
  461. inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
  462. inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
  463. return inv_freq, attention_factor
  464. def _compute_llama3_parameters(
  465. config: "PreTrainedConfig",
  466. device: Optional["torch.device"] = None,
  467. seq_len: int | None = None,
  468. layer_type: str | None = None,
  469. ) -> tuple["torch.Tensor", float]:
  470. """
  471. Computes the inverse frequencies for llama 3.1.
  472. Args:
  473. config ([`~transformers."PreTrainedConfig"`]):
  474. The model configuration. This function assumes that the config will provide at least the following
  475. properties:
  476. * rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
  477. * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
  478. * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
  479. * rope_parameters (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following
  480. keys will be accessed:
  481. * `factor` (`float`, *optional*): The scaling factor applied to the inverse frequencies when 1) the
  482. wavelength is greater than `low_freq_wavelen` prior to smoothing, and 2) to all inverse frequencies
  483. during smoothing.
  484. * `high_freq_factor` (`float`): The scale factor used to compute `high_freq_wavelen` and
  485. the value for the denominator of the smoothing factor prior to the `low_freq_factor` shift.
  486. * `low_freq_factor` (`float`): The scale factor used to compute `low_freq_wavelen` and
  487. the shift applied to the numerator and denominator of the smoothing factor.
  488. frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`.
  489. * `original_max_position_embeddings` (`int`): The original max position embeddings used
  490. during pretraining. If not provided, the function falls back to `max_position_embeddings`.
  491. Additionally, this function will make use of the following properties if they are found in the config:
  492. * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
  493. derived as hidden_size // num_attention_heads.
  494. * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
  495. the first fraction of the head_dim. Defaults to 1.0.
  496. device (`torch.device`):
  497. The device to use for initialization of the inverse frequencies.
  498. seq_len (`int`, *optional*):
  499. The current sequence length. Unused for this type of RoPE.
  500. Returns:
  501. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  502. post-processing scaling factor applied to the computed cos/sin.
  503. """
  504. # For backward compatibility standardize the `rope_parameters_dict` if it uses old format
  505. config.standardize_rope_params()
  506. rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters
  507. # Gets the default RoPE parameters
  508. base = rope_parameters_dict["rope_theta"]
  509. partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
  510. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  511. dim = int(head_dim * partial_rotary_factor)
  512. attention_factor = 1.0 # Unused in this type of RoPE
  513. # Compute the inverse frequencies
  514. inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
  515. factor = rope_parameters_dict["factor"] # `8` in the original implementation
  516. low_freq_factor = rope_parameters_dict["low_freq_factor"] # `1` in the original implementation
  517. high_freq_factor = rope_parameters_dict["high_freq_factor"] # `4` in the original implementation
  518. old_context_len = rope_parameters_dict["original_max_position_embeddings"] # `8192` in the original implementation
  519. low_freq_wavelen = old_context_len / low_freq_factor
  520. high_freq_wavelen = old_context_len / high_freq_factor
  521. wavelen = 2 * math.pi / inv_freq
  522. # wavelen < high_freq_wavelen: do nothing
  523. # wavelen > low_freq_wavelen: divide by factor
  524. inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
  525. # otherwise: interpolate between the two, using a smooth factor
  526. smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
  527. smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
  528. is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
  529. inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
  530. return inv_freq_llama, attention_factor
  531. # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
  532. # from the model config. You can append new {'rope_type': callable} pairs to this rope_parameters to enable custom RoPE
  533. # parameterizations, as long as the callable has the same signature.
  534. ROPE_INIT_FUNCTIONS: dict[str, Callable[..., tuple["torch.Tensor", float]]] = {
  535. "linear": _compute_linear_scaling_rope_parameters,
  536. "dynamic": _compute_dynamic_ntk_parameters,
  537. "yarn": _compute_yarn_parameters,
  538. "longrope": _compute_longrope_parameters,
  539. "llama3": _compute_llama3_parameters,
  540. "proportional": _compute_proportional_rope_parameters,
  541. }
  542. class RopeParameters(TypedDict):
  543. """
  544. Args:
  545. rope_theta (`float`, *optional*, defaults to `RotaryEmbeddingConfigMixin.default_theta`):
  546. The base period of the RoPE embeddings. Optional in serialized configs — if omitted,
  547. the model's `default_theta` (typically 10000.0) is used.
  548. rope_type (`str`, *optional*, defaults to "default"):
  549. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  550. 'llama3'], with 'default' being the original RoPE implementation.
  551. partial_rotary_factor (`float`, *optional*):
  552. The percentage of the query and key head embedding on which RoPE will be applied.
  553. factor (`float`, *optional*):
  554. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  555. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  556. original maximum pre-trained length.
  557. original_max_position_embeddings (`int`, *optional*):
  558. Used with 'yarn', 'longrope' and 'llama3'. The original max position embeddings used during
  559. pretraining.
  560. attention_factor (`float`, *optional*):
  561. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  562. computation. If unspecified, it defaults to value recommended by the implementation, using the
  563. `factor` field to infer the suggested value.
  564. beta_fast (`float`, *optional*):
  565. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  566. ramp function. If unspecified, it defaults to 32.
  567. beta_slow (`float`, *optional*):
  568. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  569. ramp function. If unspecified, it defaults to 1.
  570. short_factor (`list[float]`, *optional*):
  571. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  572. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  573. size divided by the number of attention heads divided by 2
  574. long_factor (`list[float]`, *optional*):
  575. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  576. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  577. size divided by the number of attention heads divided by 2
  578. low_freq_factor (`float`, *optional*):
  579. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  580. high_freq_factor (`float`, *optional*):
  581. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  582. """
  583. rope_theta: float | None
  584. rope_type: str | None
  585. partial_rotary_factor: float | None
  586. factor: float | None
  587. original_max_position_embeddings: int | None
  588. attention_factor: float | None
  589. beta_fast: float | None
  590. beta_slow: float | None
  591. short_factor: list[float] | None
  592. long_factor: list[float] | None
  593. low_freq_factor: float | None
  594. high_freq_factor: float | None
  595. class RotaryEmbeddingConfigMixin:
  596. """
  597. A Mixin containing the functionality to standardize and validate RoPE parameters.
  598. """
  599. default_theta = 10_000.0
  600. ignore_keys_at_rope_validation = set()
  601. def convert_rope_params_to_dict(self, **kwargs):
  602. rope_scaling = kwargs.pop("rope_scaling", None)
  603. self.rope_parameters = rope_scaling or self.rope_parameters
  604. self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {}
  605. # Standardize and validate the correctness of rotary position embeddings parameters. Priority for these parameters is:
  606. # 1. Values in `rope_parameters` dict (where they should be after standardization)
  607. # 2. Values in `kwargs` (i.e. it's in config.json but not MyConfig.__init__'s args)
  608. # 3. Values in the config's attributes (i.e. it's in MyConfig.__init__'s args)
  609. # 4. Default values (i.e. not present at all but other RoPE parameters are present)
  610. rope_theta = kwargs.pop("rope_theta", getattr(self, "rope_theta", self.default_theta))
  611. self.rope_parameters.setdefault("rope_theta", rope_theta)
  612. partial_rotary_factor = kwargs.get("partial_rotary_factor", getattr(self, "partial_rotary_factor", None))
  613. if partial_rotary_factor is not None:
  614. self.rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor)
  615. self.ignore_keys_at_rope_validation = self.ignore_keys_at_rope_validation | {"partial_rotary_factor"}
  616. self.standardize_rope_params()
  617. return kwargs
  618. def standardize_rope_params(self):
  619. """
  620. Helper to standardize the config's rope params field by ensuring the params are defined for each
  621. later type. For old model the fn will duplicate a single rope param in each layer type (backward compatibility)
  622. """
  623. # Move `rope_theta` and `partial_rotary_factor` to the `rope_parameters`, if not there yet
  624. rope_theta = getattr(self, "rope_theta", None)
  625. partial_rotary_factor = getattr(self, "partial_rotary_factor", None)
  626. rope_parameters = getattr(self, "rope_parameters", None) or {}
  627. layer_types = getattr(self, "layer_types", None)
  628. # Case 0: no RoPE params defined
  629. if not (rope_parameters or rope_theta):
  630. # partial_rotary_factor without rope_theta is invalid, so we don't check for it here
  631. logger.warning("`standardize_rope_params` was called but no RoPE parameters were found.")
  632. return
  633. # Case 1: RoPE param keys do not intersect with possible `layer_types` -> one global dict
  634. elif layer_types is None or rope_parameters == {} or not set(rope_parameters.keys()).issubset(layer_types):
  635. rope_parameters.setdefault("rope_type", rope_parameters.get("type", "default"))
  636. rope_parameters.setdefault("rope_theta", rope_theta)
  637. if partial_rotary_factor is not None:
  638. rope_parameters["partial_rotary_factor"] = partial_rotary_factor
  639. # Move pretraining-time maximum length to rope parameter dict for RoPE types with scaling
  640. if rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]:
  641. if hasattr(self, "original_max_position_embeddings"):
  642. # NOTE: Phi3 (and potentially other models) save `original_max_position_embeddings` field
  643. # containing the pretrained value outside rope parameters. This is an exception case where we
  644. # give priority to `self.original_max_position_embeddings
  645. self.rope_parameters["original_max_position_embeddings"] = self.original_max_position_embeddings
  646. else:
  647. self.rope_parameters.setdefault("original_max_position_embeddings", self.max_position_embeddings)
  648. # Case 2: different RoPE for each layer -> several params as nested dict
  649. else:
  650. for layer_type in set(layer_types):
  651. rope_parameters[layer_type].setdefault("rope_type", rope_parameters[layer_type].get("type", "default"))
  652. rope_parameters[layer_type].setdefault("rope_theta", rope_theta)
  653. if partial_rotary_factor is not None:
  654. rope_parameters[layer_type]["partial_rotary_factor"] = partial_rotary_factor
  655. if rope_parameters[layer_type]["rope_type"] in ["llama3", "yarn", "longrope"]:
  656. self.rope_parameters[layer_type].setdefault(
  657. "original_max_position_embeddings", self.max_position_embeddings
  658. )
  659. self.rope_parameters = rope_parameters
  660. def validate_rope(self: "PreTrainedConfig"):
  661. """
  662. Validate the RoPE config arguments, given a `"PreTrainedConfig"` object
  663. """
  664. # Don't validate if no rope_parameters found (`None`) or if it's an empty dict
  665. # Note that validation runs every time a new config is created, even if config is non-RoPE
  666. rope_parameters_dict = getattr(self, "rope_parameters", None)
  667. if not rope_parameters_dict:
  668. return
  669. if getattr(self, "layer_types", None) is not None and set(rope_parameters_dict.keys()).issubset(
  670. self.layer_types
  671. ):
  672. pass
  673. else:
  674. rope_parameters_dict = {"full_attention": rope_parameters_dict}
  675. for rope_parameters in rope_parameters_dict.values():
  676. rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default"))
  677. validation_fn = getattr(self, f"_validate_{rope_type}_rope_parameters", None)
  678. rope_parameters["rope_type"] = rope_type
  679. if validation_fn is not None:
  680. validation_fn(rope_parameters, ignore_keys=self.ignore_keys_at_rope_validation)
  681. else:
  682. logger.warning(
  683. f"Missing validation function in 'RotaryEmbeddingConfigMixin' for 'rope_type'='{rope_type}'"
  684. )
  685. def _validate_default_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
  686. required_keys = {"rope_type"}
  687. optional_keys = {"rope_theta"}
  688. received_keys = set(rope_parameters.keys())
  689. rope_type = rope_parameters["rope_type"]
  690. self._check_received_keys(
  691. rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
  692. )
  693. def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
  694. required_keys = {"rope_type", "factor"}
  695. optional_keys = {"rope_theta"}
  696. received_keys = set(rope_parameters.keys())
  697. rope_type = rope_parameters["rope_type"]
  698. self._check_received_keys(
  699. rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
  700. )
  701. factor = rope_parameters["factor"]
  702. if factor is None or not isinstance(factor, float) or factor < 1.0:
  703. logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
  704. def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
  705. required_keys = {"rope_type", "factor"}
  706. optional_keys = {"rope_theta"}
  707. received_keys = set(rope_parameters.keys())
  708. rope_type = rope_parameters["rope_type"]
  709. self._check_received_keys(
  710. rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
  711. )
  712. factor = rope_parameters["factor"]
  713. if factor is None or not isinstance(factor, float) or factor < 1.0:
  714. logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
  715. def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
  716. required_keys = {"rope_type", "factor", "original_max_position_embeddings"}
  717. optional_keys = {
  718. "rope_theta",
  719. "attention_factor",
  720. "beta_fast",
  721. "beta_slow",
  722. "mscale",
  723. "mscale_all_dim",
  724. "truncate",
  725. }
  726. received_keys = set(rope_parameters.keys())
  727. rope_type = rope_parameters["rope_type"]
  728. self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
  729. factor = rope_parameters["factor"]
  730. if factor is None or not isinstance(factor, float) or factor < 1.0:
  731. logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
  732. attention_factor = rope_parameters.get("attention_factor")
  733. if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
  734. logger.warning(
  735. f"`rope_parameters`'s attention_factor field must be a float greater than 0, got {attention_factor}"
  736. )
  737. beta_fast = rope_parameters.get("beta_fast")
  738. if beta_fast is not None and not isinstance(beta_fast, float):
  739. logger.warning(f"`rope_parameters`'s beta_fast field must be a float, got {beta_fast}")
  740. beta_slow = rope_parameters.get("beta_slow")
  741. if beta_slow is not None and not isinstance(beta_slow, float):
  742. logger.warning(f"`rope_parameters`'s beta_slow field must be a float, got {beta_slow}")
  743. if (beta_fast or 32) < (beta_slow or 1):
  744. logger.warning(
  745. f"`rope_parameters`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
  746. f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
  747. )
  748. # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
  749. # NOTE: we might get `implicit_factor == 1` if config's `original_max_position_embeddings` was
  750. # inferred from `max_position_embeddings` during standardization
  751. original_max_position_embeddings = self.rope_parameters["original_max_position_embeddings"]
  752. implicit_factor = self.max_position_embeddings / original_max_position_embeddings
  753. if implicit_factor != factor and implicit_factor != 1:
  754. logger.warning_once(
  755. f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match "
  756. "the ratio implicitly set by other parameters (implicit factor = "
  757. "post-yarn context length / pre-yarn context length = "
  758. "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = "
  759. f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
  760. "behaviour in model usage, please correct the 'original_max_position_embeddings' fields in the model config."
  761. )
  762. def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
  763. required_keys = {"rope_type", "short_factor", "long_factor", "original_max_position_embeddings"}
  764. optional_keys = {"rope_theta", "attention_factor", "factor"}
  765. received_keys = set(rope_parameters.keys())
  766. rope_type = rope_parameters["rope_type"]
  767. self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
  768. partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
  769. head_dim = getattr(self, "head_dim", self.hidden_size // self.num_attention_heads)
  770. dim = int(head_dim * partial_rotary_factor)
  771. short_factor = rope_parameters.get("short_factor")
  772. if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
  773. logger.warning(f"`rope_parameters`'s short_factor field must be a list of numbers, got {short_factor}")
  774. if len(short_factor) != dim // 2:
  775. logger.warning(
  776. f"`rope_parameters`'s short_factor field must have length {dim // 2}, got {len(short_factor)}"
  777. )
  778. long_factor = rope_parameters.get("long_factor")
  779. if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
  780. logger.warning(f"`rope_parameters`'s long_factor field must be a list of numbers, got {long_factor}")
  781. if len(long_factor) != dim // 2:
  782. logger.warning(
  783. f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}"
  784. )
  785. factor = rope_parameters.get("factor")
  786. original_max_position_embeddings = rope_parameters["original_max_position_embeddings"]
  787. # Handle Phi3 divergence: we prefer the use of `attention_factor` and/or `factor` over
  788. # `original_max_position_embeddings` to compute internal variables. The latter is undesirable
  789. if factor is None and original_max_position_embeddings is not None:
  790. logger.warning_once(
  791. "This model config has set a `rope_parameters['original_max_position_embeddings']` field, to be used together with "
  792. "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_parameters`"
  793. "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
  794. "as it is compatible with most model architectures."
  795. )
  796. elif factor is None and original_max_position_embeddings is None:
  797. logger.warning("Missing required keys in `rope_parameters`: 'factor'")
  798. elif not isinstance(factor, float) or factor < 1.0:
  799. logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
  800. attention_factor = rope_parameters.get("attention_factor")
  801. if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0.0):
  802. logger.warning(
  803. f"`rope_parameters`'s attention_factor field must be a float greater than 0, got {attention_factor}"
  804. )
  805. def _validate_llama3_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
  806. required_keys = {
  807. "rope_type",
  808. "factor",
  809. "original_max_position_embeddings",
  810. "low_freq_factor",
  811. "high_freq_factor",
  812. "rope_theta",
  813. }
  814. rope_type = rope_parameters["rope_type"]
  815. received_keys = set(rope_parameters.keys())
  816. self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
  817. factor = rope_parameters["factor"]
  818. if factor is None or not isinstance(factor, float) or factor < 1.0:
  819. logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")
  820. low_freq_factor = rope_parameters["low_freq_factor"]
  821. high_freq_factor = rope_parameters["high_freq_factor"]
  822. if low_freq_factor is None or not isinstance(low_freq_factor, float):
  823. logger.warning(f"`rope_parameters`'s low_freq_factor field must be a float, got {low_freq_factor}")
  824. if high_freq_factor is None or not isinstance(high_freq_factor, float):
  825. logger.warning(f"`rope_parameters`'s high_freq_factor field must be a float, got {high_freq_factor}")
  826. if high_freq_factor <= low_freq_factor:
  827. logger.warning(
  828. "`rope_parameters`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
  829. f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
  830. )
  831. original_max_position_embeddings = rope_parameters["original_max_position_embeddings"]
  832. if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
  833. logger.warning(
  834. "`rope_parameters`'s original_max_position_embeddings field must be an integer, got "
  835. f"{original_max_position_embeddings}"
  836. )
  837. if original_max_position_embeddings >= self.max_position_embeddings:
  838. logger.warning(
  839. "`rope_parameters`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
  840. f"{original_max_position_embeddings} and max_position_embeddings={self.max_position_embeddings}"
  841. )
  842. def _validate_proportional_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
  843. required_keys = {"rope_type", "rope_theta"}
  844. rope_type = rope_parameters["rope_type"]
  845. received_keys = set(rope_parameters.keys())
  846. self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
  847. partial_rotary_factor = rope_parameters.get("partial_rotary_factor")
  848. if partial_rotary_factor is None:
  849. logger.warning(
  850. "`rope_parameters`'s partial_rotary_factor is None. This will default to 1.0 in the computation, "
  851. "making this equivalent to the linear_scaling RoPE type. Provide a value in the range [0.0, 1.0) to "
  852. "make use of the proportional RoPE funcitonality."
  853. )
  854. @staticmethod
  855. def _check_received_keys(
  856. rope_type: str,
  857. received_keys: set,
  858. required_keys: set,
  859. optional_keys: set | None = None,
  860. ignore_keys: set | None = None,
  861. ):
  862. """Compare the received keys in `config.rope_parameters` against the expected and optional keys"""
  863. # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
  864. if "type" in received_keys:
  865. received_keys -= {"type"}
  866. required_keys.add("rope_type")
  867. optional_keys = optional_keys or set()
  868. if "partial_rotary_factor" not in optional_keys:
  869. optional_keys.add("partial_rotary_factor")
  870. # Some models need to store model-specific keys, and we don't want to throw warning at them
  871. if ignore_keys is not None:
  872. received_keys -= set(ignore_keys)
  873. missing_keys = required_keys - received_keys
  874. if missing_keys:
  875. raise KeyError(f"Missing required keys in `rope_parameters` for 'rope_type'='{rope_type}': {missing_keys}")
  876. unused_keys = received_keys - required_keys - optional_keys
  877. if unused_keys:
  878. logger.warning(f"Unrecognized keys in `rope_parameters` for 'rope_type'='{rope_type}': {unused_keys}")
  879. def rope_config_validation(config: RotaryEmbeddingConfigMixin, ignore_keys: set | None = None):
  880. """
  881. This is a deprecated function.
  882. It has been kept for backward compatibility with custom code models.
  883. """
  884. warnings.warn(
  885. "`rope_config_validation` is deprecated and has been removed. "
  886. "Its functionality has been moved to RotaryEmbeddingConfigMixin.validate_rope method. "
  887. "PreTrainedConfig inherits this class, so please call self.validate_rope() instead. "
  888. "Also, make sure to use the new rope_parameters syntax. "
  889. "You can call self.standardize_rope_params() in the meantime.",
  890. FutureWarning,
  891. )
  892. config.standardize_rope_params()
  893. config.validate_rope()