modular_minimax.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. # Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch MiniMax model."""
  16. import torch
  17. import torch.nn.functional as F
  18. from huggingface_hub.dataclasses import strict
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...configuration_utils import PreTrainedConfig
  24. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  25. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import MoeModelOutputWithPast
  28. from ...modeling_rope_utils import RopeParameters
  29. from ...processing_utils import Unpack
  30. from ...utils import TransformersKwargs, auto_docstring, logging
  31. from ...utils.generic import merge_with_config_defaults
  32. from ...utils.output_capturing import OutputRecorder, capture_outputs
  33. from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding
  34. from ..mixtral.modeling_mixtral import (
  35. MixtralAttention,
  36. MixtralDecoderLayer,
  37. MixtralForCausalLM,
  38. MixtralForQuestionAnswering,
  39. MixtralForSequenceClassification,
  40. MixtralForTokenClassification,
  41. MixtralModel,
  42. MixtralPreTrainedModel,
  43. MixtralRMSNorm,
  44. MixtralSparseMoeBlock,
  45. MixtralTopKRouter,
  46. )
# Module-level logger, obtained from the library's logging helper and keyed by this module's name.
logger = logging.get_logger(__name__)
  48. @auto_docstring(checkpoint="MiniMaxAI/MiniMax-Text-01-hf")
  49. @strict
  50. class MiniMaxConfig(PreTrainedConfig):
  51. r"""
  52. block_size (`int`, *optional*, defaults to 256):
  53. The length of each attention block, determining how queries, keys, and values
  54. are grouped and processed for intra- and inter-block attention.
  55. full_attn_alpha_factor (`float`, *optional*, defaults to 1):
  56. Weight for residual value in residual connection after normal attention.
  57. full_attn_beta_factor (`float`, *optional*, defaults to 1):
  58. Weight for hidden state value in residual connection after normal attention.
  59. linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
  60. Weight for residual value in residual connection after lightning attention.
  61. linear_attn_beta_factor (`float`, *optional*, defaults to 1):
  62. Weight for hidden state value in residual connection after lightning attention.
  63. mlp_alpha_factor (`float`, *optional*, defaults to 1):
  64. Weight for residual value in residual connection after MLP.
  65. mlp_beta_factor (`float`, *optional*, defaults to 1):
  66. Weight for hidden state value in residual connection after MLP.
  67. ```python
  68. >>> from transformers import MiniMaxModel, MiniMaxConfig
  69. >>> # Initializing a MiniMax style configuration
  70. >>> configuration = MiniMaxConfig()
  71. >>> # Initializing a model from the MiniMax style configuration
  72. >>> model = MiniMaxModel(configuration)
  73. >>> # Accessing the model configuration
  74. >>> configuration = model.config
  75. ```"""
  76. model_type = "minimax"
  77. keys_to_ignore_at_inference = ["past_key_values"]
  78. default_theta = 1000000.0
  79. base_model_tp_plan = {
  80. "layers.*.self_attn.q_proj": "colwise",
  81. "layers.*.self_attn.k_proj": "colwise",
  82. "layers.*.self_attn.v_proj": "colwise",
  83. "layers.*.self_attn.o_proj": "rowwise",
  84. "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
  85. "layers.*.mlp.experts.down_proj": "rowwise",
  86. "layers.*.mlp.experts": "moe_tp_experts",
  87. }
  88. base_model_pp_plan = {
  89. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  90. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  91. "norm": (["hidden_states"], ["hidden_states"]),
  92. }
  93. attribute_map = {"num_experts": "num_local_experts"}
  94. vocab_size: int = 32000
  95. hidden_size: int = 4096
  96. intermediate_size: int = 14336
  97. num_hidden_layers: int = 32
  98. num_attention_heads: int = 32
  99. num_key_value_heads: int = 8
  100. head_dim: int | None = None
  101. hidden_act: str = "silu"
  102. max_position_embeddings: int = 4096 * 32
  103. initializer_range: float = 0.02
  104. rms_norm_eps: float = 1e-5
  105. use_cache: bool = True
  106. pad_token_id: int | None = None
  107. bos_token_id: int | None = 1
  108. eos_token_id: int | list[int] | None = 2
  109. tie_word_embeddings: bool = False
  110. sliding_window: int | None = None
  111. attention_dropout: float | int = 0.0
  112. num_experts_per_tok: int = 2
  113. num_local_experts: int = 8
  114. output_router_logits: bool = False
  115. router_aux_loss_coef: float = 0.001
  116. router_jitter_noise: float = 0.0
  117. rope_parameters: RopeParameters | dict | None = None
  118. layer_types: list[str] | None = None
  119. block_size: int = 256
  120. full_attn_alpha_factor: int | float = 1
  121. full_attn_beta_factor: int | float = 1
  122. linear_attn_alpha_factor: int | float = 1
  123. linear_attn_beta_factor: int | float = 1
  124. mlp_alpha_factor: int | float = 1
  125. mlp_beta_factor: int | float = 1
  126. def __post_init__(self, **kwargs):
  127. if self.num_key_value_heads is None:
  128. self.num_key_value_heads = self.num_attention_heads
  129. if self.layer_types is None:
  130. self.layer_types = [
  131. "full_attention" if bool((i + 1) % 2) else "linear_attention" for i in range(self.num_hidden_layers)
  132. ]
  133. super().__post_init__(**kwargs)
# MiniMax-named RMS normalization layer: subclasses MixtralRMSNorm without overriding anything.
class MiniMaxRMSNorm(MixtralRMSNorm):
    pass
  136. class MiniMaxCache(DynamicCache):
  137. def __init__(self):
  138. super().__init__()
  139. self.linear_cache: list[torch.Tensor] = []
  140. def set_linear_cache(self, layer_idx, linear_cache):
  141. # There may be skipped layers, fill them with empty lists
  142. for _ in range(len(self.linear_cache), layer_idx + 1):
  143. self.linear_cache.append([])
  144. self.linear_cache[layer_idx] = linear_cache
  145. def get_linear_cache(self, layer_idx: int):
  146. if layer_idx < len(self):
  147. return self.linear_cache[layer_idx]
  148. return None
  149. def __len__(self):
  150. return max(super().__len__(), len(self.linear_cache))
  151. def batch_repeat_interleave(self, repeats: int):
  152. for layer_idx in range(len(self)):
  153. if self.linear_cache[layer_idx] != []:
  154. self.linear_cache[layer_idx] = self.linear_cache[layer_idx].repeat_interleave(repeats, dim=0)
  155. else:
  156. self.layers[layer_idx].batch_repeat_interleave(repeats)
  157. def batch_select_indices(self, indices: torch.Tensor):
  158. for layer_idx in range(len(self)):
  159. if self.linear_cache[layer_idx] != []:
  160. self.linear_cache[layer_idx] = self.linear_cache[layer_idx][indices, ...]
  161. else:
  162. self.layers[layer_idx].batch_select_indices(indices)
  163. def crop(self, max_length: int):
  164. raise RuntimeError("MiniMaxCache doesnot support `crop` method")
  165. class MiniMaxLightningAttention(nn.Module):
  166. def __init__(self, config: MiniMaxConfig, layer_idx: int):
  167. super().__init__()
  168. self.layer_idx = layer_idx
  169. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  170. self.num_attention_heads = config.num_attention_heads
  171. self.num_hidden_layers = config.num_hidden_layers
  172. self.block_size = config.block_size
  173. self.act_fn = ACT2FN[config.hidden_act]
  174. self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads)
  175. self.qkv_proj = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim * 3, bias=False)
  176. self.out_proj = nn.Linear(self.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  177. self.output_gate = nn.Linear(config.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
  178. slope_rate = self.get_slope_rate()
  179. query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
  180. self.register_buffer("slope_rate", slope_rate)
  181. self.register_buffer("query_decay", query_decay)
  182. self.register_buffer("key_decay", key_decay)
  183. self.register_buffer("diagonal_decay", diagonal_decay)
  184. def get_slope_rate(self):
  185. base = 1 / (2 ** (8 / self.num_attention_heads))
  186. exponent = torch.arange(self.num_attention_heads) + 1
  187. factor = 1 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
  188. rate = base**exponent
  189. rate = rate * factor
  190. rate = rate[:, None, None]
  191. return rate
  192. def decay_factors(self, slope_rate):
  193. block_size_range = torch.arange(self.block_size) + 1
  194. query_decay = torch.exp(-slope_rate * block_size_range[:, None])
  195. key_decay = torch.exp(-slope_rate * (self.block_size - block_size_range[:, None]))
  196. diagonal_decay = block_size_range[:, None] - block_size_range[None, :]
  197. diagonal_decay = diagonal_decay[None, None, :, :]
  198. diagonal_decay = slope_rate * diagonal_decay
  199. diagonal_decay = torch.where(diagonal_decay >= 0, -diagonal_decay, float("-inf"))
  200. diagonal_decay = torch.exp(diagonal_decay)
  201. return query_decay, key_decay, diagonal_decay
  202. def forward(
  203. self,
  204. hidden_states: torch.Tensor,
  205. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  206. attention_mask: torch.Tensor | None,
  207. past_key_values: Cache | None = None,
  208. **kwargs: Unpack[FlashAttentionKwargs],
  209. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  210. batch_size, seq_len, hidden_size = hidden_states.shape
  211. num_blocks = (seq_len + self.block_size - 1) // self.block_size
  212. qkv_states = self.act_fn(self.qkv_proj(hidden_states))
  213. qkv_states = qkv_states.reshape(batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim)
  214. query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=3)
  215. query_states = query_states.transpose(1, 2)
  216. key_states = key_states.transpose(1, 2)
  217. value_states = value_states.transpose(1, 2)
  218. # calculated (K.T @ V) and saved as cache
  219. attn_weights_inter = None
  220. if past_key_values is not None:
  221. attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)
  222. if attn_weights_inter is None:
  223. attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to(
  224. value_states
  225. )
  226. # apply attention_mask
  227. if attention_mask is not None:
  228. attention_mask = attention_mask.to(dtype=torch.bool) # Ensure it's a boolean tensor
  229. value_states = value_states.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(-1), 0)
  230. attn_output = []
  231. for i in range(num_blocks):
  232. start_idx = i * self.block_size
  233. end_idx = min(start_idx + self.block_size, seq_len)
  234. current_block_size = end_idx - start_idx
  235. current_query_states = query_states[:, :, start_idx:end_idx]
  236. current_key_states = key_states[:, :, start_idx:end_idx]
  237. current_value_states = value_states[:, :, start_idx:end_idx]
  238. current_query_decay = self.query_decay[:, :current_block_size]
  239. current_key_decay = self.key_decay[:, -current_block_size:]
  240. current_diagonal_decay = self.diagonal_decay[:, :, :current_block_size, :current_block_size]
  241. block_decay = torch.exp(-self.slope_rate * current_block_size)
  242. # intra: ( Q @ K.T ) @ V -> QK * V
  243. attn_weights_intra = torch.matmul(current_query_states, current_key_states.transpose(-1, -2))
  244. attn_output_intra = torch.matmul(attn_weights_intra * current_diagonal_decay, current_value_states)
  245. # inter: Q @ ( K.T @ V ) -> Q * KV
  246. attn_output_inter = torch.matmul(current_query_states * current_query_decay, attn_weights_inter)
  247. # final attention output
  248. current_attn_output = attn_output_inter + attn_output_intra
  249. attn_output.append(current_attn_output)
  250. # calculate attn_weights_inter for next block or cache
  251. next_attn_weights_inter = torch.matmul(
  252. (current_key_states * current_key_decay).transpose(-1, -2), current_value_states
  253. )
  254. attn_weights_inter = attn_weights_inter * block_decay + next_attn_weights_inter
  255. else:
  256. ratio = torch.exp(-self.slope_rate)
  257. attn_output = []
  258. for i in range(seq_len):
  259. current_query_states = query_states[:, :, i : i + 1]
  260. current_key_states = key_states[:, :, i : i + 1]
  261. current_value_states = value_states[:, :, i : i + 1]
  262. current_attn_weights_inter = torch.matmul(current_key_states.transpose(-1, -2), current_value_states)
  263. attn_weights_inter = ratio * attn_weights_inter + current_attn_weights_inter
  264. current_attn_output = torch.matmul(current_query_states, attn_weights_inter)
  265. attn_output.append(current_attn_output)
  266. # concatenate attention outputs over all blocks
  267. attn_output = torch.cat(attn_output, dim=-2)
  268. # final output projection
  269. attn_output = attn_output.transpose(1, 2)
  270. attn_output = attn_output.reshape(batch_size, seq_len, self.num_attention_heads * self.head_dim)
  271. attn_output = self.norm(attn_output)
  272. attn_output = F.sigmoid(self.output_gate(hidden_states)) * attn_output
  273. attn_output = self.out_proj(attn_output)
  274. # update cache
  275. if past_key_values is not None:
  276. past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)
  277. return attn_output, attn_weights_inter
# MiniMax-named rotary embedding: subclasses Gemma2RotaryEmbedding without overriding anything.
class MiniMaxRotaryEmbedding(Gemma2RotaryEmbedding):
    pass
# Attention used by "full_attention" layers: subclasses MixtralAttention without overriding anything.
class MiniMaxAttention(MixtralAttention):
    pass
# MiniMax-named expert router: subclasses MixtralTopKRouter without overriding anything.
class MiniMaxTopKRouter(MixtralTopKRouter):
    pass
# MiniMax-named sparse mixture-of-experts block: subclasses MixtralSparseMoeBlock without overriding anything.
class MiniMaxSparseMoeBlock(MixtralSparseMoeBlock):
    pass
  286. class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer):
  287. def __init__(self, config: MiniMaxConfig, layer_idx: int):
  288. super().__init__(config, layer_idx)
  289. self.layer_idx = layer_idx
  290. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  291. self.mlp_alpha_factor = config.mlp_alpha_factor
  292. self.mlp_beta_factor = config.mlp_beta_factor
  293. del self.mlp
  294. self.mlp = MiniMaxSparseMoeBlock(config)
  295. if self.layer_type == "linear_attention":
  296. self.self_attn = MiniMaxLightningAttention(config, layer_idx)
  297. self.attn_alpha_factor = config.linear_attn_alpha_factor
  298. self.attn_beta_factor = config.linear_attn_beta_factor
  299. else:
  300. self.self_attn = MiniMaxAttention(config, layer_idx)
  301. self.attn_alpha_factor = config.full_attn_alpha_factor
  302. self.attn_beta_factor = config.full_attn_beta_factor
  303. def forward(
  304. self,
  305. hidden_states: torch.Tensor,
  306. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  307. attention_mask: torch.Tensor | None = None,
  308. position_ids: torch.LongTensor | None = None,
  309. past_key_values: Cache | None = None,
  310. use_cache: bool | None = False,
  311. **kwargs: Unpack[FlashAttentionKwargs],
  312. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  313. hidden_states = self.input_layernorm(hidden_states)
  314. residual = hidden_states
  315. hidden_states, _ = self.self_attn(
  316. hidden_states=hidden_states,
  317. position_embeddings=position_embeddings,
  318. attention_mask=attention_mask,
  319. position_ids=position_ids,
  320. past_key_values=past_key_values,
  321. use_cache=use_cache,
  322. **kwargs,
  323. )
  324. hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor
  325. hidden_states = self.post_attention_layernorm(hidden_states)
  326. residual = hidden_states
  327. hidden_states = self.mlp(hidden_states)
  328. hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor
  329. return hidden_states
  330. class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
  331. _can_compile_fullgraph = False # uses a non-compilable custom cache class MiniMaxCache
  332. _can_record_outputs = {
  333. "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0),
  334. "hidden_states": MiniMaxDecoderLayer,
  335. "attentions": [MiniMaxAttention, MiniMaxLightningAttention],
  336. }
  337. def _init_weights(self, module):
  338. super()._init_weights(module)
  339. if isinstance(module, MiniMaxLightningAttention):
  340. slope_rate = module.get_slope_rate()
  341. query_decay, key_decay, diagonal_decay = module.decay_factors(slope_rate)
  342. init.copy_(module.slope_rate, slope_rate)
  343. init.copy_(module.query_decay, query_decay)
  344. init.copy_(module.key_decay, key_decay)
  345. init.copy_(module.diagonal_decay, diagonal_decay)
  346. class MiniMaxModel(MixtralModel):
  347. @merge_with_config_defaults
  348. @capture_outputs
  349. def forward(
  350. self,
  351. input_ids: torch.LongTensor | None = None,
  352. attention_mask: torch.Tensor | None = None,
  353. position_ids: torch.LongTensor | None = None,
  354. past_key_values: MiniMaxCache | None = None,
  355. inputs_embeds: torch.FloatTensor | None = None,
  356. use_cache: bool | None = None,
  357. **kwargs: Unpack[TransformersKwargs],
  358. ) -> tuple | MoeModelOutputWithPast:
  359. if (input_ids is None) ^ (inputs_embeds is not None):
  360. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  361. if use_cache and past_key_values is None:
  362. past_key_values = MiniMaxCache()
  363. elif use_cache and not isinstance(past_key_values, MiniMaxCache):
  364. raise ValueError(
  365. f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
  366. )
  367. if inputs_embeds is None:
  368. inputs_embeds = self.embed_tokens(input_ids)
  369. if position_ids is None:
  370. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  371. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  372. position_ids = position_ids.unsqueeze(0)
  373. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  374. causal_mask = mask_function(
  375. config=self.config,
  376. inputs_embeds=inputs_embeds,
  377. attention_mask=attention_mask,
  378. past_key_values=past_key_values,
  379. position_ids=position_ids,
  380. )
  381. hidden_states = inputs_embeds
  382. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  383. for i, decoder_layer in enumerate(self.layers):
  384. if self.config.layer_types[i] == "full_attention":
  385. input_attention_mask = causal_mask
  386. else:
  387. # lightning attention uses original attention_mask, and uses it only for the first step
  388. input_attention_mask = attention_mask
  389. hidden_states = decoder_layer(
  390. hidden_states,
  391. attention_mask=input_attention_mask,
  392. position_embeddings=position_embeddings,
  393. position_ids=position_ids,
  394. past_key_values=past_key_values,
  395. use_cache=use_cache,
  396. **kwargs,
  397. )
  398. hidden_states = self.norm(hidden_states)
  399. return MoeModelOutputWithPast(
  400. last_hidden_state=hidden_states,
  401. past_key_values=past_key_values,
  402. )
# Causal language-modeling head for MiniMax. `forward` only forwards its keyword arguments to
# MixtralForCausalLM.forward; it is overridden here to attach the MiniMax-specific docstring.
class MiniMaxForCausalLM(MixtralForCausalLM):
    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniMaxForCausalLM

        >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        return super().forward(**super_kwargs)
# MiniMax-named sequence-classification model: subclasses MixtralForSequenceClassification unchanged.
class MiniMaxForSequenceClassification(MixtralForSequenceClassification):
    pass
# MiniMax-named token-classification model: subclasses MixtralForTokenClassification unchanged.
class MiniMaxForTokenClassification(MixtralForTokenClassification):
    pass
# MiniMax-named question-answering model: subclasses MixtralForQuestionAnswering unchanged.
class MiniMaxForQuestionAnswering(MixtralForQuestionAnswering):
    pass
# Names exported by `from <module> import *`. The helper classes defined above
# (cache, attention variants, norm, router, MoE block, decoder layer) are not listed.
__all__ = [
    "MiniMaxConfig",
    "MiniMaxPreTrainedModel",
    "MiniMaxModel",
    "MiniMaxForCausalLM",
    "MiniMaxForSequenceClassification",
    "MiniMaxForTokenClassification",
    "MiniMaxForQuestionAnswering",
]