moe.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. from functools import wraps
  16. from ..utils import logging
  17. from ..utils.generic import GeneralInterface
  18. from ..utils.import_utils import is_torch_available, is_torch_less_or_equal, is_torchdynamo_compiling
  19. if is_torch_available():
  20. import torch
  21. logger = logging.get_logger(__name__)
  22. # Examples of experts class with its eager mm implementation
  23. # class Experts(torch.nn.Module):
  24. # """Collection of expert weights stored as 3D tensors."""
  25. # def __init__(self, config):
  26. # super().__init__()
  27. # self.num_experts = config.n_routed_experts
  28. # self.hidden_dim = config.hidden_size
  29. # self.intermediate_dim = config.moe_intermediate_size
  30. # self.gate_up_proj = torch.nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
  31. # self.down_proj = torch.nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
  32. # self.act_fn = ACT2FN[config.hidden_act]
  33. # def forward(
  34. # self,
  35. # hidden_states: torch.Tensor,
  36. # top_k_index: torch.Tensor,
  37. # top_k_weights: torch.Tensor,
  38. # ) -> torch.Tensor:
  39. # final_hidden_states = torch.zeros_like(hidden_states)
  40. # with torch.no_grad():
  41. # expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
  42. # expert_mask = expert_mask.permute(2, 1, 0)
  43. # expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
  44. # for expert_idx in expert_hit:
  45. # expert_idx = expert_idx[0]
  46. # if expert_idx == self.num_experts:
  47. # continue
  48. # top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
  49. # current_state = hidden_states[token_idx]
  50. # gate, up = torch.nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
  51. # current_hidden_states = self.act_fn(gate) * up
  52. # current_hidden_states = torch.nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
  53. # current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
  54. # final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
  55. # return final_hidden_states
  56. def _batched_linear(
  57. input: torch.Tensor,
  58. weight: torch.Tensor,
  59. bias: torch.Tensor | None = None,
  60. is_transposed: bool = False,
  61. ) -> torch.Tensor:
  62. """Batched linear layer supporting optional bias and transposed weights.
  63. Args:
  64. input (`torch.Tensor`):
  65. Input tensor of shape (batch_size, input_dim).
  66. weight (`torch.Tensor`):
  67. Weight tensor of shape (batch_size, output_dim, input_dim) if transposed is `False`,
  68. else of shape (batch_size, input_dim, output_dim).
  69. bias (`torch.Tensor`, *optional*):
  70. Bias tensor of shape (batch_size, output_dim). Default is `None`.
  71. is_transposed (`bool`, *optional*, defaults to `False`):
  72. Whether the weight tensor is transposed.
  73. Returns:
  74. `torch.Tensor`: Output tensor of shape (batch_size, output_dim).
  75. """
  76. if is_transposed:
  77. # (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim)
  78. out = torch.bmm(input.unsqueeze(1), weight).squeeze(1)
  79. else:
  80. # (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim)
  81. out = torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1)
  82. if bias is not None:
  83. out = out + bias
  84. return out
  85. def batched_mm_experts_forward(
  86. self: torch.nn.Module,
  87. hidden_states: torch.Tensor,
  88. top_k_index: torch.Tensor,
  89. top_k_weights: torch.Tensor,
  90. ) -> torch.Tensor:
  91. device = hidden_states.device
  92. num_top_k = top_k_index.size(-1)
  93. num_tokens = hidden_states.size(0)
  94. hidden_dim = hidden_states.size(-1)
  95. # Reshape for easier indexing
  96. # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k)
  97. token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,)
  98. sample_weights = top_k_weights.reshape(-1) # (S,)
  99. expert_ids = top_k_index.reshape(-1) # (S,)
  100. # Handle invalid expert IDs from Expert Parallelism (EP)
  101. # When EP is enabled, tokens assigned to experts on other devices are marked with sentinel value >= num_experts
  102. invalid_mask = expert_ids >= self.num_experts
  103. expert_ids = expert_ids.clamp(0, self.num_experts - 1)
  104. # Get current hidden states for selected samples
  105. selected_hidden_states = hidden_states[token_idx]
  106. # Select gate_up or just up projection weights and biases
  107. if self.has_gate:
  108. selected_weights = self.gate_up_proj[expert_ids]
  109. selected_biases = self.gate_up_proj_bias[expert_ids] if self.has_bias else None
  110. else:
  111. selected_weights = self.up_proj[expert_ids]
  112. selected_biases = self.up_proj_bias[expert_ids] if self.has_bias else None
  113. # --- Up projection per expert (batched) ---
  114. proj_out = _batched_linear(
  115. selected_hidden_states, selected_weights, bias=selected_biases, is_transposed=self.is_transposed
  116. ) # (S, 2 * intermediate_dim) or (S, intermediate_dim) depending on whether we have gating
  117. # Apply gating or activation
  118. if self.has_gate:
  119. # for gated experts we apply the custom/default gating mechanism
  120. proj_out = self._apply_gate(proj_out) # (S, intermediate_dim)
  121. else:
  122. # for non-gated experts we just apply the activation function
  123. proj_out = self.act_fn(proj_out) # (S, intermediate_dim)
  124. # Select down projection weights and biases for selected samples
  125. selected_weights = self.down_proj[expert_ids]
  126. selected_biases = self.down_proj_bias[expert_ids] if self.has_bias else None
  127. # --- Down projection per expert (batched) ---
  128. proj_out = _batched_linear(
  129. proj_out, selected_weights, bias=selected_biases, is_transposed=self.is_transposed
  130. ) # (S, hidden_dim)
  131. # Apply routing weights and zero out invalid expert contributions
  132. weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim)
  133. weighted_out.masked_fill_(invalid_mask.unsqueeze(-1), 0.0) # Zero out invalid expert contributions
  134. # Accumulate results using deterministic reshape+sum instead of index_add_
  135. # index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd
  136. # index_add_ accumulates in-place using the dtype of the output tensor (fp16/bf16)
  137. # reshape+sum accumulates in fp32 which is more stable for low precision training/inference.
  138. final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
  139. return final_hidden_states.to(hidden_states.dtype)
# torch.compiler.disable does not work with fullgraph=True, so we wrap this function in a custom operator to make it opaque to the compiler.
# This compatibility is not "free": inductor can no longer optimize the matmuls inside the loop,
# but since those matmuls have dynamic shapes, inductor would not have been able to optimize them anyway.
  143. def _grouped_mm_fallback(input: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
  144. """
  145. Fallback grouped matrix multiplication used when `torch.nn.functional.grouped_mm` and `torch._grouped_mm`
  146. are unavailable or incompatible with `torch.compile` (e.g. non-bfloat16 weights).
  147. Args:
  148. input (`torch.Tensor`): Input of shape (S, input_dim), sorted by expert id.
  149. weight (`torch.Tensor`): Expert weights of shape (num_experts, input_dim, output_dim).
  150. offs (`torch.Tensor`): Cumulative token counts per expert of shape (num_experts,).
  151. Returns:
  152. `torch.Tensor`: Output of shape (S, output_dim).
  153. """
  154. output = torch.zeros(input.size(0), weight.size(2), device=input.device, dtype=input.dtype) # (S, output_dim)
  155. start = 0
  156. # single cpu<->gpu sync point here,
  157. # avoids multiple syncs inside the loop
  158. for i, end in enumerate(offs.tolist()):
  159. if start == end:
  160. continue
  161. torch.mm(input[start:end], weight[i], out=output[start:end])
  162. start = end
  163. return output
  164. def _grouped_mm_fallback_fake(input: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
  165. """Shape/dtype inference stub for `_grouped_mm_fallback` required by `torch.compile`."""
  166. assert input.dim() == 2, f"input must be 2D (S, input_dim), got shape {tuple(input.shape)}"
  167. assert weight.dim() == 3, (
  168. f"weight must be 3D (num_experts, input_dim, output_dim), got shape {tuple(weight.shape)}"
  169. )
  170. assert offs.dim() == 1, f"offs must be 1D (num_experts,), got shape {tuple(offs.shape)}"
  171. assert offs.size(0) == weight.size(0), f"offs length {offs.size(0)} must match number of experts {weight.size(0)}"
  172. assert input.size(1) == weight.size(1), (
  173. f"input_dim mismatch: input has {input.size(1)}, weight has {weight.size(1)}"
  174. )
  175. assert offs.dtype in (torch.int32, torch.int64), f"offs must be an integer tensor, got {offs.dtype}"
  176. return torch.empty(input.size(0), weight.size(2), device=input.device, dtype=input.dtype)
  177. def _grouped_mm_fallback_setup_context(ctx, inputs, output):
  178. """Saves input and weight for backward; offs is stored directly as it is a non-differentiable integer tensor."""
  179. ctx.save_for_backward(inputs[0], inputs[1])
  180. ctx.offs = inputs[2]
  181. def _grouped_mm_fallback_backward(ctx, grad_output):
  182. """Backward pass for `_grouped_mm_fallback`. Computes grad_input and grad_weight per expert group; offs has no gradient."""
  183. input, weight = ctx.saved_tensors
  184. grad_input = torch.zeros_like(input)
  185. grad_weight = torch.zeros_like(weight)
  186. start = 0
  187. # single cpu<->gpu sync point here,
  188. # avoids multiple syncs inside the loop
  189. for i, end in enumerate(ctx.offs.tolist()):
  190. if start == end:
  191. continue
  192. torch.mm(grad_output[start:end], weight[i].T, out=grad_input[start:end])
  193. torch.mm(input[start:end].T, grad_output[start:end], out=grad_weight[i])
  194. start = end
  195. return grad_input, grad_weight, None
if is_torch_available():
    # Expose the Python loop as the custom operator `transformers::grouped_mm_fallback`
    # (called below through `torch.ops.transformers.grouped_mm_fallback`).
    # `mutates_args=()` declares that none of the arguments is modified in place.
    torch.library.custom_op("transformers::grouped_mm_fallback", _grouped_mm_fallback, mutates_args=())
    # Shape/dtype stub used in place of the real implementation when only metadata is needed.
    torch.library.register_fake("transformers::grouped_mm_fallback", _grouped_mm_fallback_fake)
    # Hand-written backward, with `setup_context` saving the tensors it needs from the forward call.
    torch.library.register_autograd(
        "transformers::grouped_mm_fallback",
        _grouped_mm_fallback_backward,
        setup_context=_grouped_mm_fallback_setup_context,
    )
  204. def _can_use_grouped_mm(input: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor) -> bool:
  205. """
  206. Check if torch.nn.functional.grouped_mm or torch._grouped_mm can be used based on availability and compatibility with torch.compile.
  207. Args:
  208. input (`torch.Tensor`):
  209. Input tensor of shape (S, input_dim).
  210. weight (`torch.Tensor`):
  211. Weight tensor of shape (num_experts, input_dim, output_dim).
  212. offs (`torch.Tensor`):
  213. Offsets tensor indicating the boundaries of each group in the input tensor.
  214. Returns:
  215. `bool`: True if grouped_mm can be used, False otherwise.
  216. """
  217. if (is_torchdynamo_compiling() and weight.dtype != torch.bfloat16) or (
  218. weight.device.type == "cpu"
  219. # accept_dev=True is necessary for "+cpu"/"+xpu" etc.
  220. and is_torch_less_or_equal("2.10.0", accept_dev=True)
  221. and (weight.data_ptr() % 16 != 0 or input.data_ptr() % 16 != 0)
  222. ):
  223. # Under the following conditions we cannot use torch.grouped_mm and have to fall back:
  224. # 1. torch.grouped_mm is not supported in torch.compile / inductor with dtypes other than bf16
  225. # 2. Before PyTorch 2.11, torch.grouped_mm on CPU required 16 bytes alignment which is not
  226. # guaranteed for tensors loaded using memmap (e.g. using safetensors lazy tensor loading)
  227. # and not really necessary because the cpu path uses a fallback for-loop implementation.
  228. # issue: https://github.com/pytorch/pytorch/issues/172440
  229. return False
  230. return hasattr(torch.nn.functional, "grouped_mm") or hasattr(torch, "_grouped_mm")
  231. def _grouped_mm(
  232. input: torch.Tensor,
  233. weight: torch.Tensor,
  234. offs: torch.Tensor,
  235. ) -> torch.Tensor:
  236. """Grouped matrix multiplication dispatcher that uses torch.nn.functional.grouped_mm if available, else falls back to torch._grouped_mm.
  237. Args:
  238. input (`torch.Tensor`):
  239. Input tensor of shape (S, input_dim).
  240. weight (`torch.Tensor`):
  241. Weight tensor of shape (num_experts, input_dim, output_dim).
  242. offs (`torch.Tensor`):
  243. Offsets tensor indicating the boundaries of each group in the input tensor.
  244. Returns:
  245. `torch.Tensor`: Output tensor of shape (S, output_dim).
  246. """
  247. if _can_use_grouped_mm(input, weight, offs):
  248. # torch.nn.functional.grouped_mm and torch._grouped_mm are not autocast-enabled,
  249. # when autocast is enabled we can end up with intermediate tensors in fp32 (e.g. LayerNorm output) and weight tensors in bf16
  250. # In that case we need to cast the input to the weight dtype to avoid dtype mismatch errors.
  251. # See: https://github.com/pytorch/pytorch/issues/174763
  252. if hasattr(torch.nn.functional, "grouped_mm"):
  253. return torch.nn.functional.grouped_mm(input.to(weight.dtype), weight, offs=offs)
  254. elif hasattr(torch, "_grouped_mm"):
  255. return torch._grouped_mm(input.to(weight.dtype), weight, offs=offs)
  256. return torch.ops.transformers.grouped_mm_fallback(input, weight, offs=offs)
  257. def _grouped_linear(
  258. input: torch.Tensor,
  259. weight: torch.Tensor,
  260. offs: torch.Tensor,
  261. bias: torch.Tensor | None = None,
  262. is_transposed: bool = False,
  263. ) -> torch.Tensor:
  264. """Grouped linear layer supporting optional bias and transposed weights.
  265. Args:
  266. input (`torch.Tensor`):
  267. Input tensor of shape (S, input_dim).
  268. weight (`torch.Tensor`):
  269. Weight tensor of shape (num_experts, input_dim, output_dim) if `is_transposed`,
  270. else of shape (num_experts, output_dim, input_dim).
  271. offs (`torch.Tensor`):
  272. Offsets tensor indicating the boundaries of each group in the input tensor.
  273. bias (`torch.Tensor`, *optional*):
  274. Bias tensor of shape (num_experts, output_dim). Default is `None`.
  275. is_transposed (`bool`, *optional*, defaults to `False`):
  276. Whether the weight tensor is transposed.
  277. Returns:
  278. `torch.Tensor`: Output tensor of shape (S, output_dim).
  279. """
  280. if is_transposed:
  281. # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim)
  282. out = _grouped_mm(input, weight, offs=offs)
  283. else:
  284. # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim)
  285. out = _grouped_mm(input, weight.transpose(-2, -1), offs=offs)
  286. if bias is not None:
  287. # We should be able to pass bias to the grouped_mm call, but it's not yet supported.
  288. out = out + bias
  289. return out
  290. def grouped_mm_experts_forward(
  291. self: torch.nn.Module,
  292. hidden_states: torch.Tensor,
  293. top_k_index: torch.Tensor,
  294. top_k_weights: torch.Tensor,
  295. ) -> torch.Tensor:
  296. device = hidden_states.device
  297. num_top_k = top_k_index.size(-1)
  298. num_tokens = hidden_states.size(0)
  299. hidden_dim = hidden_states.size(-1)
  300. # Reshape for easier indexing
  301. # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k)
  302. token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,)
  303. sample_weights = top_k_weights.reshape(-1) # (S,)
  304. expert_ids = top_k_index.reshape(-1) # (S,)
  305. # Sort by expert for grouped processing
  306. perm = torch.argsort(expert_ids)
  307. inv_perm = torch.empty_like(perm)
  308. inv_perm[perm] = torch.arange(perm.size(0), device=device)
  309. expert_ids_g = expert_ids[perm]
  310. sample_weights_g = sample_weights[perm]
  311. selected_hidden_states_g = hidden_states[token_idx[perm]]
  312. # Compute offsets for grouped_mm
  313. # using histc instead of bincount to avoid cuda graph issues
  314. # With deterministic algorithms, CPU only supports float input, CUDA only supports int input.
  315. histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int()
  316. tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1)
  317. offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32)
  318. # Select expert weights and biases
  319. # NOTE: We keep all experts here and rely on offsets to target the active ones.
  320. # I have already implemented a version that only passes the active experts, but
  321. # to do so I had to use torch.unique which breaks the graph capture (data-dependent).
  322. # Also there were no speedup gains from it in my experiments, even in eager mode.
  323. if self.has_gate:
  324. selected_weights = self.gate_up_proj
  325. selected_biases = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None
  326. else:
  327. selected_weights = self.up_proj
  328. selected_biases = self.up_proj_bias[expert_ids_g] if self.has_bias else None
  329. # --- Up projection per expert (grouped) ---
  330. proj_out = _grouped_linear(
  331. selected_hidden_states_g, selected_weights, offsets, bias=selected_biases, is_transposed=self.is_transposed
  332. ) # (S, 2 * intermediate_dim) or (S, intermediate_dim) depending on whether we have gating
  333. # Apply gating or activation
  334. if self.has_gate:
  335. # for gated experts we apply the custom/default gating mechanism
  336. proj_out = self._apply_gate(proj_out) # (S, intermediate_dim)
  337. else:
  338. # for non-gated experts we just apply the activation function
  339. proj_out = self.act_fn(proj_out) # (S, intermediate_dim)
  340. # Select down projection weights and biases
  341. selected_weights = self.down_proj
  342. selected_biases = self.down_proj_bias[expert_ids_g] if self.has_bias else None
  343. # --- Down projection per expert (grouped) ---
  344. proj_out = _grouped_linear(
  345. proj_out, selected_weights, offsets, bias=selected_biases, is_transposed=self.is_transposed
  346. ) # (S, hidden_dim)
  347. # Apply routing weights
  348. weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim)
  349. # Restore original order
  350. weighted_out = weighted_out[inv_perm] # (S, hidden_dim)
  351. # Accumulate results using deterministic reshape+sum instead of index_add_
  352. # index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd
  353. # index_add_ accumulates in-place using the dtype of the output tensor (fp16/bf16)
  354. # reshape+sum accumulates in fp32 which is more stable for low precision training/inference.
  355. final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
  356. return final_hidden_states.to(hidden_states.dtype)
  357. class ExpertsInterface(GeneralInterface):
  358. """Interface for registering custom experts forward functions."""
  359. _global_mapping = {
  360. "batched_mm": batched_mm_experts_forward,
  361. "grouped_mm": grouped_mm_experts_forward,
  362. }
  363. def get_interface(self, experts_implementation: str, default: Callable) -> Callable:
  364. """Return the requested `experts_implementation`. Also strictly check its validity, and raise if invalid."""
  365. if experts_implementation is None:
  366. logger.warning_once(
  367. "You tried to access the `ExpertsInterface` with a `config._experts_implementation` set to `None`. This "
  368. "is expected if you use an Expert Module as a standalone Module. If this is not the case, something went "
  369. "wrong with the dispatch of `config._experts_implementation`"
  370. )
  371. elif experts_implementation != "eager" and experts_implementation not in self:
  372. raise KeyError(
  373. f"`{experts_implementation}` is not a valid experts implementation registered in the `ExpertsInterface`"
  374. )
  375. return super().get(experts_implementation, default)
  376. ALL_EXPERTS_FUNCTIONS = ExpertsInterface()
  377. def _default_apply_gate(self, gate_up_out: torch.Tensor) -> torch.Tensor:
  378. """
  379. Default gating mechanism: splits the gate_up_out into gate and up parts,
  380. applies the activation function to the gate part, and multiplies it with the up part.
  381. Args:
  382. gate_up_out (`torch.Tensor`):
  383. The output tensor from the gate and up projection of shape (S, 2 * intermediate_dim).
  384. Returns:
  385. `torch.Tensor`: The gated output tensor of shape (S, intermediate_dim).
  386. """
  387. gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim)
  388. return self.act_fn(gate) * up # (S, intermediate_dim)
  389. def use_experts_implementation(
  390. experts_class: type[torch.nn.Module] | None = None,
  391. *,
  392. experts_interface: ExpertsInterface = ALL_EXPERTS_FUNCTIONS,
  393. is_transposed: bool = False,
  394. has_bias: bool = False,
  395. has_gate: bool = True,
  396. ) -> type[torch.nn.Module]:
  397. """Decorator to modify experts class to support different experts implementations.
  398. Args:
  399. experts_class (`type[torch.nn.Module]`, *optional*):
  400. The experts class to modify. If not provided, returns a decorator that can be applied to the class.
  401. experts_interface (`ExpertsInterface`, *optional*, defaults to `ALL_EXPERTS_FUNCTIONS`):
  402. The experts interface to use for dispatching the forward method.
  403. is_transposed (`bool`, *optional*, defaults to `False`):
  404. Whether the expert weights are stored in transposed format.
  405. has_bias (`bool`, *optional*, defaults to `False`):
  406. Whether the expert layers include bias terms.
  407. Returns:
  408. `type[torch.nn.Module]`: The modified experts class.
  409. """
  410. def wrapper(experts_class: type[torch.nn.Module]) -> type[torch.nn.Module]:
  411. original_init = experts_class.__init__
  412. original_forward = experts_class.forward
  413. @wraps(original_init)
  414. def __init__(self, config, *args, **kwargs):
  415. original_init(self, config, *args, **kwargs)
  416. self.config = config
  417. self.has_gate = has_gate
  418. self.has_bias = has_bias
  419. self.is_transposed = is_transposed
  420. @wraps(original_forward)
  421. def forward(self, *args, **kwargs):
  422. experts_forward = experts_interface.get_interface(self.config._experts_implementation, original_forward)
  423. return experts_forward(self, *args, **kwargs)
  424. if not hasattr(experts_class, "_apply_gate"):
  425. experts_class._apply_gate = _default_apply_gate
  426. experts_class.__init__ = __init__
  427. experts_class.forward = forward
  428. return experts_class
  429. if experts_class is not None:
  430. return wrapper(experts_class)
  431. return wrapper