# modeling_arcee.py
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/arcee/modular_arcee.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_arcee.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from transformers.utils import auto_docstring
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import (
  31. GenericForQuestionAnswering,
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, can_return_tuple
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_arcee import ArceeConfig
  44. class ArceeMLP(nn.Module):
  45. def __init__(self, config):
  46. super().__init__()
  47. self.config = config
  48. self.hidden_size = config.hidden_size
  49. self.intermediate_size = config.intermediate_size
  50. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  51. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  52. self.act_fn = ACT2FN[config.hidden_act]
  53. def forward(self, x):
  54. return self.down_proj(self.act_fn(self.up_proj(x)))
  55. @use_kernel_forward_from_hub("RMSNorm")
  56. class ArceeRMSNorm(nn.Module):
  57. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  58. """
  59. ArceeRMSNorm is equivalent to T5LayerNorm
  60. """
  61. super().__init__()
  62. self.weight = nn.Parameter(torch.ones(hidden_size))
  63. self.variance_epsilon = eps
  64. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  65. input_dtype = hidden_states.dtype
  66. hidden_states = hidden_states.to(torch.float32)
  67. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  68. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  69. return self.weight * hidden_states.to(input_dtype)
  70. def extra_repr(self):
  71. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class ArceeRotaryEmbedding(nn.Module):
    """Builds the rotary position embedding (RoPE) cos/sin tables passed to the attention layers.

    Inverse frequencies come from `compute_default_rope_parameters` when
    `config.rope_parameters["rope_type"] == "default"`, otherwise from `ROPE_INIT_FUNCTIONS[rope_type]`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ArceeConfig, device=None):
        super().__init__()
        # Both start at config.max_position_embeddings. Presumably read/updated by `dynamic_rope_update`
        # for dynamic RoPE variants (that decorator is not visible here — TODO confirm).
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]

        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # attention_scaling is multiplied into cos/sin in forward().
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # persistent=False: neither buffer is written to the state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Copy of the initial frequencies; presumably used by `dynamic_rope_update` to restore
        # inv_freq after it has been changed (TODO confirm).
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: ArceeConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation.

        For every even index j = 0, 2, 4, ... < dim, the inverse frequency is 1 / base ** (j / dim),
        where base = config.rope_parameters["rope_theta"] and dim = config.head_dim, falling back to
        hidden_size // num_attention_heads when head_dim is missing or None.

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies (exponents built as int64 then cast to float before dividing)
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, scaled by `attention_scaling` and cast to `x.dtype`.

        `x` is only used for its device and dtype. Assuming `position_ids` is (batch, seq_len) — TODO
        confirm with callers — the outputs are (batch, seq_len, 2 * len(inv_freq)).
        """
        # (batch, len(inv_freq), 1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        # inserts a singleton dim after the batch dim: (batch, 1, seq_len) for 2-D position_ids
        position_ids_expanded = position_ids[:, None, :].float()

        # autocast is given "cpu" as the device type for MPS tensors (and for non-str device types).
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Each frequency appears twice along the last dim, matching rotate_half's split in two halves.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  125. def rotate_half(x):
  126. """Rotates half the hidden dims of the input."""
  127. x1 = x[..., : x.shape[-1] // 2]
  128. x2 = x[..., x.shape[-1] // 2 :]
  129. return torch.cat((-x2, x1), dim=-1)
  130. @use_kernel_func_from_hub("rotary_pos_emb")
  131. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  132. """Applies Rotary Position Embedding to the query and key tensors.
  133. Args:
  134. q (`torch.Tensor`): The query tensor.
  135. k (`torch.Tensor`): The key tensor.
  136. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  137. sin (`torch.Tensor`): The sine part of the rotary embedding.
  138. unsqueeze_dim (`int`, *optional*, defaults to 1):
  139. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  140. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  141. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  142. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  143. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  144. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  145. Returns:
  146. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  147. """
  148. cos = cos.unsqueeze(unsqueeze_dim)
  149. sin = sin.unsqueeze(unsqueeze_dim)
  150. q_embed = (q * cos) + (rotate_half(q) * sin)
  151. k_embed = (k * cos) + (rotate_half(k) * sin)
  152. return q_embed, k_embed
  153. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  154. """
  155. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  156. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  157. """
  158. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  159. if n_rep == 1:
  160. return hidden_states
  161. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  162. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  163. def eager_attention_forward(
  164. module: nn.Module,
  165. query: torch.Tensor,
  166. key: torch.Tensor,
  167. value: torch.Tensor,
  168. attention_mask: torch.Tensor | None,
  169. scaling: float,
  170. dropout: float = 0.0,
  171. **kwargs: Unpack[TransformersKwargs],
  172. ):
  173. key_states = repeat_kv(key, module.num_key_value_groups)
  174. value_states = repeat_kv(value, module.num_key_value_groups)
  175. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  176. if attention_mask is not None:
  177. attn_weights = attn_weights + attention_mask
  178. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  179. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  180. attn_output = torch.matmul(attn_weights, value_states)
  181. attn_output = attn_output.transpose(1, 2).contiguous()
  182. return attn_output, attn_weights
  183. @use_kernelized_func(apply_rotary_pos_emb)
  184. class ArceeAttention(nn.Module):
  185. """Multi-headed attention from 'Attention Is All You Need' paper"""
  186. def __init__(self, config: ArceeConfig, layer_idx: int):
  187. super().__init__()
  188. self.config = config
  189. self.layer_idx = layer_idx
  190. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  191. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  192. self.scaling = self.head_dim**-0.5
  193. self.attention_dropout = config.attention_dropout
  194. self.is_causal = True
  195. self.q_proj = nn.Linear(
  196. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  197. )
  198. self.k_proj = nn.Linear(
  199. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  200. )
  201. self.v_proj = nn.Linear(
  202. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  203. )
  204. self.o_proj = nn.Linear(
  205. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  206. )
  207. def forward(
  208. self,
  209. hidden_states: torch.Tensor,
  210. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  211. attention_mask: torch.Tensor | None = None,
  212. past_key_values: Cache | None = None,
  213. **kwargs: Unpack[TransformersKwargs],
  214. ) -> tuple[torch.Tensor, torch.Tensor]:
  215. input_shape = hidden_states.shape[:-1]
  216. hidden_shape = (*input_shape, -1, self.head_dim)
  217. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  218. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  219. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  220. cos, sin = position_embeddings
  221. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  222. if past_key_values is not None:
  223. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  224. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  225. self.config._attn_implementation, eager_attention_forward
  226. )
  227. attn_output, attn_weights = attention_interface(
  228. self,
  229. query_states,
  230. key_states,
  231. value_states,
  232. attention_mask,
  233. dropout=0.0 if not self.training else self.attention_dropout,
  234. scaling=self.scaling,
  235. **kwargs,
  236. )
  237. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  238. attn_output = self.o_proj(attn_output)
  239. return attn_output, attn_weights
  240. class ArceeDecoderLayer(GradientCheckpointingLayer):
  241. def __init__(self, config: ArceeConfig, layer_idx: int):
  242. super().__init__()
  243. self.hidden_size = config.hidden_size
  244. self.self_attn = ArceeAttention(config=config, layer_idx=layer_idx)
  245. self.mlp = ArceeMLP(config)
  246. self.input_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  247. self.post_attention_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  248. def forward(
  249. self,
  250. hidden_states: torch.Tensor,
  251. attention_mask: torch.Tensor | None = None,
  252. position_ids: torch.LongTensor | None = None,
  253. past_key_values: Cache | None = None,
  254. use_cache: bool | None = False,
  255. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  256. **kwargs: Unpack[TransformersKwargs],
  257. ) -> torch.Tensor:
  258. residual = hidden_states
  259. hidden_states = self.input_layernorm(hidden_states)
  260. # Self Attention
  261. hidden_states, _ = self.self_attn(
  262. hidden_states=hidden_states,
  263. attention_mask=attention_mask,
  264. position_ids=position_ids,
  265. past_key_values=past_key_values,
  266. use_cache=use_cache,
  267. position_embeddings=position_embeddings,
  268. **kwargs,
  269. )
  270. hidden_states = residual + hidden_states
  271. # Fully Connected
  272. residual = hidden_states
  273. hidden_states = self.post_attention_layernorm(hidden_states)
  274. hidden_states = self.mlp(hidden_states)
  275. hidden_states = residual + hidden_states
  276. return hidden_states
  277. @auto_docstring
  278. class ArceePreTrainedModel(PreTrainedModel):
  279. config: ArceeConfig
  280. base_model_prefix = "model"
  281. supports_gradient_checkpointing = True
  282. _no_split_modules = ["ArceeDecoderLayer"]
  283. _skip_keys_device_placement = ["past_key_values"]
  284. _supports_flash_attn = True
  285. _supports_sdpa = True
  286. _supports_flex_attn = True
  287. _can_compile_fullgraph = True
  288. _supports_attention_backend = True
  289. _can_record_outputs = {
  290. "hidden_states": ArceeDecoderLayer,
  291. "attentions": ArceeAttention,
  292. }
  293. @auto_docstring
  294. class ArceeModel(ArceePreTrainedModel):
  295. def __init__(self, config: ArceeConfig):
  296. super().__init__(config)
  297. self.padding_idx = config.pad_token_id
  298. self.vocab_size = config.vocab_size
  299. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  300. self.layers = nn.ModuleList(
  301. [ArceeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  302. )
  303. self.norm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  304. self.rotary_emb = ArceeRotaryEmbedding(config=config)
  305. self.gradient_checkpointing = False
  306. # Initialize weights and apply final processing
  307. self.post_init()
  308. @merge_with_config_defaults
  309. @capture_outputs
  310. @auto_docstring
  311. def forward(
  312. self,
  313. input_ids: torch.LongTensor | None = None,
  314. attention_mask: torch.Tensor | None = None,
  315. position_ids: torch.LongTensor | None = None,
  316. past_key_values: Cache | None = None,
  317. inputs_embeds: torch.FloatTensor | None = None,
  318. use_cache: bool | None = None,
  319. **kwargs: Unpack[TransformersKwargs],
  320. ) -> BaseModelOutputWithPast:
  321. if (input_ids is None) ^ (inputs_embeds is not None):
  322. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  323. if inputs_embeds is None:
  324. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  325. if use_cache and past_key_values is None:
  326. past_key_values = DynamicCache(config=self.config)
  327. if position_ids is None:
  328. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  329. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  330. position_ids = position_ids.unsqueeze(0)
  331. causal_mask = create_causal_mask(
  332. config=self.config,
  333. inputs_embeds=inputs_embeds,
  334. attention_mask=attention_mask,
  335. past_key_values=past_key_values,
  336. position_ids=position_ids,
  337. )
  338. hidden_states = inputs_embeds
  339. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  340. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  341. hidden_states = decoder_layer(
  342. hidden_states,
  343. attention_mask=causal_mask,
  344. position_embeddings=position_embeddings,
  345. position_ids=position_ids,
  346. past_key_values=past_key_values,
  347. use_cache=use_cache,
  348. **kwargs,
  349. )
  350. hidden_states = self.norm(hidden_states)
  351. return BaseModelOutputWithPast(
  352. last_hidden_state=hidden_states,
  353. past_key_values=past_key_values,
  354. )
  355. @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
  356. class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
  357. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  358. _tp_plan = {"lm_head": "colwise_gather_output"}
  359. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  360. def __init__(self, config):
  361. super().__init__(config)
  362. self.model = ArceeModel(config)
  363. self.vocab_size = config.vocab_size
  364. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  365. # Initialize weights and apply final processing
  366. self.post_init()
  367. @can_return_tuple
  368. @auto_docstring
  369. def forward(
  370. self,
  371. input_ids: torch.LongTensor | None = None,
  372. attention_mask: torch.Tensor | None = None,
  373. position_ids: torch.LongTensor | None = None,
  374. past_key_values: Cache | None = None,
  375. inputs_embeds: torch.FloatTensor | None = None,
  376. labels: torch.LongTensor | None = None,
  377. use_cache: bool | None = None,
  378. logits_to_keep: int | torch.Tensor = 0,
  379. **kwargs: Unpack[TransformersKwargs],
  380. ) -> CausalLMOutputWithPast:
  381. r"""
  382. Example:
  383. ```python
  384. >>> from transformers import AutoTokenizer, ArceeForCausalLM
  385. >>> model = ArceeForCausalLM.from_pretrained("meta-arcee/Arcee-2-7b-hf")
  386. >>> tokenizer = AutoTokenizer.from_pretrained("meta-arcee/Arcee-2-7b-hf")
  387. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  388. >>> inputs = tokenizer(prompt, return_tensors="pt")
  389. >>> # Generate
  390. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  391. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  392. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  393. ```"""
  394. outputs: BaseModelOutputWithPast = self.model(
  395. input_ids=input_ids,
  396. attention_mask=attention_mask,
  397. position_ids=position_ids,
  398. past_key_values=past_key_values,
  399. inputs_embeds=inputs_embeds,
  400. use_cache=use_cache,
  401. **kwargs,
  402. )
  403. hidden_states = outputs.last_hidden_state
  404. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  405. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  406. logits = self.lm_head(hidden_states[:, slice_indices, :])
  407. loss = None
  408. if labels is not None:
  409. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  410. return CausalLMOutputWithPast(
  411. loss=loss,
  412. logits=logits,
  413. past_key_values=outputs.past_key_values,
  414. hidden_states=outputs.hidden_states,
  415. attentions=outputs.attentions,
  416. )
  417. @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
  418. class ArceeForSequenceClassification(GenericForSequenceClassification, ArceePreTrainedModel):
  419. pass
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForQuestionAnswering(GenericForQuestionAnswering, ArceePreTrainedModel):
    # Overrides ArceePreTrainedModel.base_model_prefix ("model") for this head only.
    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
  423. @auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
  424. class ArceeForTokenClassification(GenericForTokenClassification, ArceePreTrainedModel):
  425. pass
  426. __all__ = [
  427. "ArceeForCausalLM",
  428. "ArceeForQuestionAnswering",
  429. "ArceeForSequenceClassification",
  430. "ArceeForTokenClassification",
  431. "ArceeModel",
  432. "ArceePreTrainedModel",
  433. ]